# Llama 3 (Meta-Llama-3-8B-Instruct): predicting column names of relational tables

In [3]:
# Hugging Face checkpoint to load, plus the short size tag that is embedded in result file names.
model_id = 'meta-llama/Meta-Llama-3-8B-Instruct'
model_size = '8B'

In [72]:
# ---- Experiment parameters --------------------------------------------------
# Table filter: "single" keeps only one-column tables, "multi" keeps only tables with 2+ columns.
column_type = "single"
# Sub-directories that are scanned for table files.
directory_indexs = [
    "K0", "K1", "K2", "K3", "K4",
]
# Per directory, stop after this many tables of the selected type have been scored.
file_num = 100
# How many data rows (after the header line) are pasted into each prompt.
rows_num = 20
# Generation budget forwarded to the pipeline as max_new_tokens.
max_new_tokens = 4096
# Column-count mismatch policy: True drops the table, False scores it as all-"???" predictions.
ignore_mismatch = False
# GPU index passed to transformers.pipeline(device=...).
gpu_device = 2

In [2]:
# Load the chat model once; the resulting `pipeline` object is reused by every
# generation call in the prediction loop. Weights are loaded in bfloat16 on GPU `gpu_device`.
import torch
import transformers

pipeline = transformers.pipeline(
    task="text-generation",
    model=model_id,
    device=gpu_device,
    model_kwargs=dict(torch_dtype=torch.bfloat16),
)


  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [16:07<00:00, 241.81s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [63]:
import os
import json
from json import loads, JSONDecodeError
from tqdm import tqdm
import numpy as np
# from importlib import reload
# import sys
from npy_postprocess import canonical_header
from f1_llama import report_gen
# print(sys.modules['f1_llama'])
# reload(sys.modules['f1_llama'])

In [73]:
# ---- Prompts ----------------------------------------------------------------
# User-prompt preamble: restricts the model to a fixed vocabulary of column names
# and shows the expected JSON reply shape.
template_context = """
Column Names are limited to the following:
name, description, team, type, age, location, year, city, rank, status, state, category,
weight, code, club, artist, result, position, country, notes, class, company, album, symbol,
address, duration, format, county, day, gender, industry, language, sex, product, jockey,
region, area, service, teamName, order, isbn, fileSize, grades, publisher, plays, origin,
elevation, affiliation, component, owner, genre,  manufacturer, brand, family, credit, depth,
classification, collection, species, command, nationality, currency, range, affiliate,
birthDate, ranking, capacity, birthPlace, person, creator, operator, religion, education,
requirement, director, sales, continent, organisation
Do not use any column names aside from these.

Output must be in valid JSON like the following example {"colnames" : ["col1", "col2"]}

Given the following relational table:
"""

# System prompt, identical for every table (hoisted out of the loop). The leading
# whitespace inside the literal is kept as-is so the prompt text is unchanged.
system_prompt = """
            You are a database expert who can make general predictions for missing column values in database tables, and the predicted column names are within the required candidate set. All output must be in valid JSON. Don't add explanation beyond the JSON.
            Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
            """

# Generation stops at either the tokenizer's EOS token or Llama-3's end-of-turn token.
terminators = [
    pipeline.tokenizer.eos_token_id,
    pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]


def _read_table(path):
    """Read one table file and return ``(colnames, data_lines)``.

    The first line is treated as a comma-separated header. The remaining lines are
    returned verbatim (newlines included) so they can be pasted into the prompt.
    An empty file yields ``([], [])``, which both column_type filters then skip.
    """
    with open(path, encoding="utf-8") as f:
        linelist = f.readlines()
    if not linelist:
        return [], []
    # rstrip rather than [:-1]: the slice chopped a real character when the header had
    # no trailing newline, and left a stray '\r' on CRLF files.
    colnames = linelist[0].rstrip("\r\n").split(",")
    return colnames, linelist[1:]


def _build_messages(colnames, data_lines):
    """Build the chat messages asking the model to name the columns of one table.

    Reads the notebook-level config ``rows_num`` and ``column_type``.
    """
    lines = "".join(data_lines[:rows_num])
    content = f"{template_context} \n {lines} Guess the column names for the whole table. There are only {len(colnames)} columns in the table."
    if column_type == "multi":
        # Leading space: previously this sentence was glued onto "...in the table."
        content += " It is possible for multiple columns to have the same name.\n"
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": content},
    ]


def _generate(messages):
    """Run the chat model on ``messages`` and return only the newly generated text."""
    prompt = pipeline.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    outputs = pipeline(
        prompt,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        pad_token_id=pipeline.tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )
    # The pipeline echoes the prompt in generated_text; strip it to keep the reply.
    return outputs[0]["generated_text"][len(prompt):]


def _parse_colnames(text):
    """Return the ``colnames`` list from a model reply, or None if the reply is malformed.

    "Malformed" covers invalid JSON, and also JSON that is not an object, lacks the
    "colnames" key, or whose "colnames" is not a list of strings. Previously those
    latter cases raised KeyError/TypeError and aborted the whole run.
    """
    try:
        reply = loads(text)
    except JSONDecodeError:
        return None
    items = reply.get("colnames") if isinstance(reply, dict) else None
    if not isinstance(items, list) or not all(isinstance(item, str) for item in items):
        return None
    return items


# ---- Prediction loop ----------------------------------------------------------
trues = []
preds = []

true_path = "npy/trues/"
pred_path = "npy/preds/"
os.makedirs(true_path, exist_ok=True)
os.makedirs(pred_path, exist_ok=True)

for tabledir in tqdm(directory_indexs):
    real_cols = []
    pred_cols = []
    mismatch = 0   # replies whose column count differs from the table's
    error_num = 0  # replies that could not be parsed as {"colnames": [...]}
    file_cnt = 0   # tables scored so far (parse errors are not counted)
    # sorted() so the same tables are selected on every run and platform
    # (os.listdir order is arbitrary).
    for filename in sorted(os.listdir(tabledir)):
        # Checked before processing so that file_num == 0 selects nothing
        # (the old post-check never fired in that case); None means "no limit".
        if file_num is not None and file_cnt >= file_num:
            break
        colnames, data_lines = _read_table(os.path.join(tabledir, filename))
        if column_type == "single" and len(colnames) != 1:
            continue
        if column_type == "multi" and len(colnames) <= 1:
            continue

        items = _parse_colnames(_generate(_build_messages(colnames, data_lines)))
        if items is None:
            error_num += 1
            continue

        if len(items) == len(colnames):
            real_cols += colnames
            pred_cols += [canonical_header(item) for item in items]
        else:
            mismatch += 1
            if not ignore_mismatch:
                # Keep the table in the score as all-wrong predictions.
                real_cols += colnames
                pred_cols += ["???"] * len(colnames)
        file_cnt += 1

    print(len(pred_cols), len(real_cols), mismatch, error_num)
    # dtype=str lets numpy size the field to the longest entry; the former fixed
    # '<U14' silently truncated labels/predictions longer than 14 characters.
    np.save(os.path.join(true_path, f"{column_type}_{tabledir}_true.npy"),
            np.array(real_cols, dtype=str))
    np.save(os.path.join(pred_path, f"{column_type}_{tabledir}_pred.npy"),
            np.array(pred_cols, dtype=str))
    trues += real_cols
    preds += pred_cols
                

 20%|██        | 1/5 [00:39<02:38, 39.67s/it]

100 100 7 0


 40%|████      | 2/5 [01:30<02:17, 45.95s/it]

100 100 6 1


 60%|██████    | 3/5 [02:13<01:30, 45.01s/it]

100 100 10 0


 80%|████████  | 4/5 [02:55<00:43, 43.86s/it]

100 100 9 1


100%|██████████| 5/5 [03:37<00:00, 43.49s/it]

100 100 11 1





In [71]:
# Score the pooled predictions against the ground truth, then write the overall
# metrics (JSON) and the per-label report (CSV) into results_path.
results_path = "results/"
os.makedirs(results_path, exist_ok=True)

# The prediction loop always extends trues and preds by the same length per table
# (matches, "???"-padded mismatches, or nothing), so they must stay aligned.
assert len(preds) == len(trues), (len(preds), len(trues))
print(len(preds), len(trues))

overall, report = report_gen(preds, trues)

overall_file = os.path.join(results_path, f"{column_type}_overall_{model_size}.json")
report_file = os.path.join(results_path, f"{column_type}_report_{model_size}.csv")
with open(overall_file, "w") as f:
    json.dump(overall, f)
# NOTE(review): report is assumed to be a pandas DataFrame (it is written via .to_csv).
report.to_csv(report_file, index=False)
print(f"Results are successfully written into {overall_file} and {report_file}")

1181 1181
Results are successfully written into results/multi_overall_8B.json and results/multi_report_8B.csv


  df_report= pd.concat([df_report, pd.DataFrame(report[t],index=[0])], ignore_index=True)
