In [1]:
import csv
import json
import os

import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from f1_llama import report_gen
from npy_postprocess import canonical_header

# Base model checkpoint and the tag used in the result file names.
model_id = "meta-llama/Llama-2-7b-chat-hf"
model_size = "7B_adapter"

# ---- Experiment parameters -------------------------------------------------
# Label embedded in every output file name. This script does not filter tables
# by column count: every column of each file is predicted, one column per prompt.
column_type = "single"
# Sub-directories of CSV tables that are evaluated, one after another.
directory_indexs = ["K0", "K1", "K2", "K3", "K4"]
# Number of files taken from each directory.
file_num = 100
# Number of leading values of each column that are placed in the prompt.
rows_num = 20
# Upper bound on generated tokens per answer.
# NOTE(review): this equals the 4096-token context Llama-2 chat models are
# usually configured with, so prompt + new tokens can exceed it; only the first
# sentence of the reply is parsed, so a much smaller cap is likely enough — confirm.
max_new_tokens = 4096
# Passed to from_pretrained as `device_map`: "auto" lets transformers decide
# device placement; a GPU index (e.g. 1) puts the whole model on that device.
gpu_device = "auto"
# When True, adapter weights from `adapter_id` are loaded on top of the base model.
enable_adapter = True
adapter_id = "sadpineapple/llama2-7b-chat-adapter"

# ---- Model and tokenizer -----------------------------------------------------
load_options = dict(
    device_map=gpu_device,
    trust_remote_code=False,
    revision="main",
)
model = AutoModelForCausalLM.from_pretrained(model_id, **load_options)
if enable_adapter:
    # transformers' adapter loading relies on the `peft` package being installed.
    model.load_adapter(adapter_id)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

# Closed label vocabulary shown to the model. The same list is repeated verbatim
# (but indented) inside `sysprompt` below — keep the two copies in sync when
# editing either one.
template_context = """
Column Names are limited to the following:
name, description, team, type, age, location, year, city, rank, status, state, category,
weight, code, club, artist, result, position, country, notes, class, company, album, symbol,
address, duration, format, county, day, gender, industry, language, sex, product, jockey,
region, area, service, teamName, order, isbn, fileSize, grades, publisher, plays, origin,
elevation, affiliation, component, owner, genre,  manufacturer, brand, family, credit, depth,
classification, collection, species, command, nationality, currency, range, affiliate,
birthDate, ranking, capacity, birthPlace, person, creator, operator, religion, education,
requirement, director, sales, continent, organisation
Do not use any column names aside from these.
"""

# Prefix of each user message: the vocabulary plus a lead-in sentence. The
# sampled column values are appended after it when the prompt is built.
single_context = template_context + """
Given the following column values in a relational table:"""

# System message sent with every request: the same vocabulary (each line
# indented, so the text differs from template_context in whitespace) plus an
# instruction to reply with the bare label only.
# NOTE: these strings are sent to the model as-is, so any whitespace change
# alters the prompt.
sysprompt = {"role" : "system", "content" : """
    Column Names are limited to the following:
    name, description, team, type, age, location, year, city, rank, status, state, category,
    weight, code, club, artist, result, position, country, notes, class, company, album, symbol,
    address, duration, format, county, day, gender, industry, language, sex, product, jockey,
    region, area, service, teamName, order, isbn, fileSize, grades, publisher, plays, origin,
    elevation, affiliation, component, owner, genre,  manufacturer, brand, family, credit, depth,
    classification, collection, species, command, nationality, currency, range, affiliate,
    birthDate, ranking, capacity, birthPlace, person, creator, operator, religion, education,
    requirement, director, sales, continent, organisation
    Do not use any column names aside from these.
    
    No pre-amble. Answer is in the following format: answer
    """}

def _extract_label(reply):
    """Pull the predicted column name out of a model reply.

    Returns the last whitespace-separated word of the reply's first sentence
    (newlines are treated as spaces), or "" when that sentence contains no
    words — e.g. an empty reply or one that starts with ".".
    """
    words = reply.replace('\n', ' ').split('.')[0].split()
    return words[-1] if words else ""


def _read_header(path):
    """Return the first CSV record of *path* (its header row) as a list of strings.

    Parsed with the csv module so quoted fields containing commas and CRLF
    line endings are split correctly. Returns [] for an empty file.
    """
    with open(path, newline='', encoding='utf-8') as fh:
        return next(csv.reader(fh), [])


trues = []  # ground-truth headers from every directory, in file/column order
preds = []  # canonicalised predictions, aligned position-by-position with `trues`

true_path = "npy/trues/"
pred_path = "npy/preds/"
os.makedirs(true_path, exist_ok=True)
os.makedirs(pred_path, exist_ok=True)

# The pipeline only wraps the already-loaded model and tokenizer, so it is
# built once here instead of once per input file.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=max_new_tokens,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    top_k=40,
    repetition_penalty=1.1,
)

for tabledir in tqdm(directory_indexs):
    # os.listdir order is arbitrary; sort so the same `file_num` files are
    # evaluated on every run.
    filenames = sorted(os.listdir(tabledir))
    real_cols = []
    pred_cols = []
    mismatch = 0   # files skipped because header and parsed column counts differ
    error_num = 0  # unreadable files plus replies with no extractable label
    for filename in filenames[:file_num]:
        fullpath = os.path.join(tabledir, filename)
        try:
            colnames = _read_header(fullpath)
            df = pd.read_csv(fullpath).astype(str)
        except (UnicodeDecodeError, pd.errors.EmptyDataError, pd.errors.ParserError) as err:
            print(f"skipping unreadable file {fullpath}: {err}")
            error_num += 1
            continue
        # real_cols and pred_cols are compared position by position downstream,
        # so a file whose header disagrees with the parsed columns is skipped
        # entirely rather than allowed to shift every later pair.
        if len(colnames) != df.shape[1]:
            mismatch += 1
            continue
        real_cols += colnames

        for _, values in df.items():
            sample = ', '.join(values.iloc[:rows_num])
            content = f"{single_context} {sample}\nGuess the column name"
            messages = [sysprompt, {"role": "user", "content": content}]
            # With chat-format input the pipeline returns the whole
            # conversation; its last message is the model's reply.
            reply = pipe(messages)[0]['generated_text'][-1]['content']
            label = _extract_label(reply)
            if not label:
                error_num += 1
            # Append even when parsing failed so pred_cols stays aligned with real_cols.
            pred_cols.append(canonical_header(label) if label else "")

    print(len(pred_cols), len(real_cols), mismatch, error_num)
    # dtype=str sizes the unicode dtype to the longest label, so nothing is
    # truncated on save (a fixed '<U14' would clip longer names).
    np.save(os.path.join(true_path, f'{column_type}_{tabledir}_true.npy'),
            np.asarray(real_cols, dtype=str))
    np.save(os.path.join(pred_path, f'{column_type}_{tabledir}_pred.npy'),
            np.asarray(pred_cols, dtype=str))
    trues += real_cols
    preds += pred_cols
                

# ---- Aggregate scoring over all directories -------------------------------------
results_path = "results/"
os.makedirs(results_path, exist_ok=True)
print(len(preds), len(trues))

# report_gen returns a JSON-serialisable summary and a tabular report
# (presumably a pandas DataFrame, since it is written with .to_csv — confirm).
overall, report = report_gen(preds, trues)

# Build both output paths once from results_path so the directory name is
# defined in a single place.
overall_file = os.path.join(results_path, f"{column_type}_overall_{model_size}.json")
report_file = os.path.join(results_path, f"{column_type}_report_{model_size}.csv")
with open(overall_file, "w", encoding="utf-8") as f:
    json.dump(overall, f)
report.to_csv(report_file, index=False)
print(f"Results are successfully written into {overall_file} and {report_file}")




Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

adapter_config.json:   0%|          | 0.00/654 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/16.8M [00:00<?, ?B/s]

 20%|██        | 1/5 [02:47<11:11, 167.77s/it]

146 146 0 0


 40%|████      | 2/5 [05:19<07:54, 158.14s/it]

143 143 0 0


 60%|██████    | 3/5 [08:09<05:27, 163.63s/it]

164 164 0 0


 80%|████████  | 4/5 [10:43<02:39, 160.00s/it]

151 151 0 0


100%|██████████| 5/5 [13:19<00:00, 159.94s/it]

147 147 0 0
751 751





Results are successfully written into results/single_overall_7B_adapter.json and results/single_report_7B_adapter.csv
