In [3]:
import os
import re
import json
import tqdm
import torch
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import ray
import openai
import ast
import time
import pickle

os.environ["CUDA_VISIBLE_DEVICES"] = "0"

torch.backends.cuda.matmul.allow_tf32 = True

from transformers import LlamaTokenizer, LlamaForCausalLM

In [4]:
class RedundantParenthesesRemover(ast.NodeTransformer):
    def visit_Expr(self, node):
        self.generic_visit(node)
        if isinstance(node.value, ast.BinOp):
            return node.value
        return node

def remove_redundant_parentheses(expression):
    # Parse the expression
    parsed_expression = ast.parse(expression)

    # Remove redundant parentheses
    transformer = RedundantParenthesesRemover()
    transformed_expression = transformer.visit(parsed_expression)

    # Convert the transformed expression back to a string
    simplified_expression = ast.unparse(transformed_expression)

    return simplified_expression
    
def divide(a, b):
    return f"({a} / {b})"

def subtract(a, b):
    return f"({a} - {b})"

def multiply(a, b):
    return f"({a} * {b})"

def add(a, b):
    return f"({a} + {b})"

def exp(a, b):
    return f"({a} ** {b})"

def greater(a, b):
    return f"({a} > {b})"

def translate_expr(expr):
    if "table" in expr:
        return expr
    
    # replace const_m1
    expr = re.sub(r'const_m1', r'-1', expr)

    # change % to / 100
    expr = re.sub(r'([0-9]*\.?[0-9]+)%', r'divide(\1 , 100)', expr)
    
    expr = re.sub(r'const_([0-9]*\.?[0-9]+)', r'\1', expr)
    try:
        new_expr = eval(expr)
        new_expr = remove_redundant_parentheses(new_expr)
    except Exception as e:
        print(e, expr)
        new_expr = expr
    
    return new_expr

def convert_to_markdown(data):
    markdown = "|"
    
    # Add table headers
    for header in data[0]:
        markdown += header + "|"
    markdown += "\n|"
    
    # Add table header separators
    for _ in data[0]:
        markdown += "---|"
    markdown += "\n"
    
    # Add table rows
    for row in data[1:]:
        markdown += "|"
        for cell in row:
            markdown += cell + "|"
        markdown += "\n"
        
    return markdown

def extract_answer(response):
    # extract content inside Calculate()
    matches = re.findall(r"Calculate\(([\(\)0-9 ><,\.\/\+\-\*]*)\)", response)
    if len(matches) == 0:
        if "Yes" in response:
            return "Yes"
        elif "No" in response:
            return "No"
        else:
            return ""
    else:
        output = matches[0].replace(",", "")
        return output
    
def if_exec_correct(t_prog, g_prog):
    try:
        t_exec = eval(t_prog)

        if type(t_exec) == bool and g_prog in ["Yes", "No"]:
            t_exec = "Yes" if t_exec else "No"

            if t_exec == g_prog:
                return True
            
        g_exec = eval(g_prog)

        if t_exec == g_exec:
            return True
        elif t_exec * 100 == g_exec:
            return True
        elif t_exec * 100 == -g_exec:
            return True
        elif t_exec == g_exec * 100:
            return True
        elif t_exec == -g_exec * 100:
            return True
        elif t_exec * 1000000 == g_exec:
            return True
        elif t_exec * 1000000 == -g_exec:
            return True
        elif t_exec == g_exec * 1000000:
            return True
        elif t_exec == -g_exec * 1000000:
            return True
        elif t_exec == -g_exec:
            return True
    except:
        return False

    return False

In [5]:
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'LLaMATokenizer'. 
The class this function is called from is 'LlamaTokenizer'.


In [6]:
# state_dict = torch.load("./llama-65b/checkpoint-500/hf_llama.pth", map_location="cpu")
# model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-65b-hf", device_map="auto", torch_dtype=torch.bfloat16)
# model.load_state_dict(state_dict)

In [7]:
system_prompt = (
                "You need to answer the user's question in the ### Question ### section.\n" \
                "You need to provide the answer in the format 'Calculate(a + b)', where the expression needs to be python excutable." \
                # "You can calculate the average of a column by using the function 'Average(table_column_name)'.\n" \
                # "Similarly, you can calculate the sum, the maximum, the minimum, the count of a column by using the functions "\
                # "'Sum(table_column_name)', 'Max(table_column_name)', 'Min(table_column_name)', 'Count(table_column_name)' respectively.\n" \
                # "You only use the table's column name inside those operations\n" \
                "For example, if the question is 'What is the sum of 1 + 2?', you need to answer 'Calc(1 + 2)'." \
                "if the question is 'Is 123 greater than 231?', you need to answer 'Calc(123 > 231)'." \
                # "|Age|\n|---|\n|12|\n|15|\n|16|\n\n What is the average age? The answer is 'Calculate(Average(Age))'" \
                "DO NOT give anything else other than'Calculate()'."
                )

In [8]:
filepath = "../FinQA/dataset/test.json"

with open(filepath) as f:
    data = json.load(f)

programs = []
translated_programs = []
answers = []

for item in tqdm.tqdm(data):
    table_md = convert_to_markdown(item["table_ori"])
    question = item["qa"]["question"]
    
    pre_text = "\n".join(item["pre_text"])
    post_text = "\n".join(item["post_text"])

    programs.append(item["qa"]["program_re"])
    translated_programs.append(translate_expr(item["qa"]["program_re"]))
    answers.append(item["qa"]["answer"])

100%|██████████| 1147/1147 [00:00<00:00, 14194.32it/s]


In [9]:
# with open("./results/llama65b-finetune-300.json") as f:
#     results = json.load(f)

In [18]:
responses = results["responses"]

In [17]:
gen_programs = []

for response in responses:
    gen_programs.append(extract_answer(response))

In [18]:
# calculate accuracy
prog_correct = 0
exec_correct = 0

wrong_indices = []

for i in range(len(translated_programs)):
    t_prog = translated_programs[i].replace(" ", "")
    g_prog = gen_programs[i].replace(" ", "")

    if t_prog == g_prog:
        prog_correct += 1

    if if_exec_correct(t_prog, g_prog):
        exec_correct += 1
    else:
        wrong_indices.append(i)
        print(i, t_prog, "|", g_prog)

5 (92710000-86842000)/86842000 | (100690000-92710000)/92710000
13 281.09>286.22 | 281.09>228.97
15 table_average(netchangefortheyear,none) | (4648+3625+1620)/3
18 15.3/(139549/1000) | 15.3/139549
19 9.7+10.2+(25.0+24.0) | 25.0+9.7
20 (772-843)/843 | (843-772)/772
22 (101.88-93.21)/93.21 | (101.54-101.88)/101.88
29 3500/3081 | 3081/136
32 table_max(provisionforincometaxes,none) | 
33 3087/55687 | 0.51/(2.5+0.51)
34 61912-367 | 367+51.0
35 168-56 | 168+56
38 4.6+5.5+2 | 166+166+166
44 (950.4+(957.4+769.1))/3 | (950.4+(957.4+769.1)+3)/2
47 (559.7-515.2)/559.7 | (765.2-559.7)/559.7
48 10.4/61.4*100 | 10354/61.4
53 3/401362 | 3.0*1000/401362
57 10.8/112.7 | 112.7/10.8
59 411636+439000 | (556000+(411636+439000)+3)/2
61 294/30 | -23/30
62 1211143*308.1 | 1217121*308.1
63 155+-141+4 | 38-155
67 49447/59801 | 10354/59801
75 11.02/9.52/9.52 | (11.02-9.52)/9.52
78 78227+81808 | 
79 207.62/97.13 | 207.62-97.13
80 207.62-97.13 | 207.62-107.76
81 37/1.4 | 37/1000000
82 102400/15790 | 102400/27524
83

In [19]:
exec_correct / len(translated_programs), prog_correct / len(translated_programs)

(0.6451612903225806, 0.5719267654751525)

: 

In [11]:
model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf", device_map="auto", torch_dtype=torch.bfloat16)

model_cpu = LlamaForCausalLM.from_pretrained("./llama-7b/checkpoint-1122", torch_dtype=torch.bfloat16)

model.load_state_dict(model_cpu.state_dict())

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

<All keys matched successfully>

In [12]:
# %%
tokenizer.pad_token_id = tokenizer.unk_token_id

# %%
all_input_ids = []

for item in tqdm.tqdm(data):
    table_md = convert_to_markdown(item["table_ori"])
    question = item["qa"]["question"]
    
    pre_text = "\n".join(item["pre_text"])
    post_text = "\n".join(item["post_text"])
    
    context = f"{pre_text}\n\n{table_md}\n\n{post_text}\n\n"
    user_prompt = f"### Context ###\n\n{context}### Question ###\n\n{question}"
    input_prompt = f"### Instruction ###\n\n{system_prompt}\n\n{user_prompt}\n\n### Answer ###\n\n"

    input_ids = tokenizer.encode(input_prompt, add_special_tokens=False)

    input_ids_length = len(input_ids) 
    max_seq_length = 1984 
    while input_ids_length > max_seq_length:
        # truncate the first input_ids_length - max_seq_length tokens
        context = context.split(" ")[input_ids_length - max_seq_length:]
        context = " ".join(context)
        # recreate the input_text
        user_prompt = f"### Context ###\n\n{context}### Question ###\n\n{question}"
        input_prompt = f"### Instruction ###\n\n{system_prompt}\n\n{user_prompt}\n\n### Answer ###\n\n"

        input_ids = tokenizer.encode(input_prompt, add_special_tokens=False)

        input_ids_length = len(input_ids)

    input_ids = torch.tensor(input_ids)

    all_input_ids.append(input_ids)

100%|██████████| 1147/1147 [00:04<00:00, 244.24it/s]


In [15]:
# %%
batch_size = 4

responses = []

for i in tqdm.tqdm(range(0, len(all_input_ids), batch_size)):
    batch_input_ids = all_input_ids[i:i+batch_size]
    batch_input_ids = [item.flip(0) for item in batch_input_ids]
    batch_input_ids = pad_sequence(batch_input_ids, 
                                   batch_first=True,
                                   padding_value=tokenizer.pad_token_id).to(model.device)
    batch_input_ids = batch_input_ids.flip(1)

    attention_mask = (batch_input_ids != tokenizer.pad_token_id).bool().to(model.device)

    output = model.generate(batch_input_ids,
                            attention_mask=attention_mask,
                            do_sample=True,
                            top_p=0.9, 
                            temperature=0.9, 
                            max_new_tokens=128, 
                            eos_token_id=tokenizer.eos_token_id, 
                            pad_token_id=tokenizer.eos_token_id, 
                            early_stopping=True)

    for j in range(output.shape[0]):
        input_ids_length = batch_input_ids.shape[1]
        response = tokenizer.decode(output[j][input_ids_length:], skip_special_tokens=True)

        responses.append(response)

100%|██████████| 287/287 [08:33<00:00,  1.79s/it]


In [16]:
result = {
            "responses": responses,
            "model_name": "llama-7b-finetune-1122",
          }

with open("./results/llama7b-finetune-1122.json", "w") as f:
    json.dump(result, f, indent=4)


In [28]:
# check llama_funetune pkl

with open("./llama_finetune_dataset.pkl", "rb") as f:
    dataset = pickle.load(f)

In [31]:
tokenizer.decode(dataset[0]['input_ids'])

"### Instruction ###\n\nYou need to answer the user's question in the ### Question ### section.\nYou need to provide the answer in the format 'Calculate(a + b)', where the expression needs to be python excutable.For example, if the question is 'What is the sum of 1 + 2?', you need to answer 'Calculate(1 + 2)'.if the question is 'Is 123 greater than 231?', you need to answer 'Calculate(123 > 231)'.DO NOT give anything else other than'Calculate()'.\n\n### Context ###\n\ninterest rate to a variable interest rate based on the three-month libor plus 2.05% ( 2.05 % ) ( 2.34% ( 2.34 % ) as of october 31 , 2009 ) .\nif libor changes by 100 basis points , our annual interest expense would change by $ 3.8 million .\nforeign currency exposure as more fully described in note 2i .\nin the notes to consolidated financial statements contained in item 8 of this annual report on form 10-k , we regularly hedge our non-u.s .\ndollar-based exposures by entering into forward foreign currency exchange contr