# llama3

In [26]:
import transformers
import torch

model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

pipeline = transformers.pipeline(
    "text-generation",
    model=model_id,
    model_kwargs={"torch_dtype": torch.bfloat16},
    device=2,
)


Loading checkpoint shards: 100%|██████████| 4/4 [14:46<00:00, 221.52s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [27]:
import os
from json import loads
from tqdm import tqdm
import numpy as np

In [34]:
template_context = """
Column Names are limited to the following:
name, description, team, type, age, location, year, city, rank, status, state, category,
weight, code, club, artist, result, position, country, notes, class, company, album, symbol,
address, duration, format, county, day, gender, industry, language, sex, product, jockey,
region, area, service, teamName, order, isbn, fileSize, grades, publisher, plays, origin,
elevation, affiliation, component, owner, genre,  manufacturer, brand, family, credit, depth,
classification, collection, species, command, nationality, currency, range, affiliate,
birthDate, ranking, capacity, birthPlace, person, creator, operator, religion, education,
requirement, director, sales, continent, organisation
Do not use any column names aside from these.

Output must be in valid JSON like the following example {"colnames" : ["col1", "col2"]}

Given the following relational table:
"""

tabledir = "K4" # Note that it doesn't end with '/'
filenames = os.listdir(tabledir)
real_cols = []
pred_cols = []
#i = 0
for filename in tqdm(filenames[0:100]):
    with open(tabledir + '/' + filename) as f:
        linelist = f.readlines()
        colnames = linelist[0][:-1].split(',')
        
        lines = ''.join(linelist[1:21])#.replace(',',';')
        messages = [
            {"role": "system", "content": """
        You are a database expert who can make general predictions for missing column values in database tables, and the predicted column names are within the required candidate set. All output must be in valid JSON. Don't add explanation beyond the JSON.
        Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature. If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
        """},
            {"role": "user", "content": f"{template_context} \n {lines} Guess the column names for the whole table. There are only {len(colnames)} columns in the table. It is possible for multiple columns to have the same name.\n"},
        ]

        prompt = pipeline.tokenizer.apply_chat_template(
                messages, 
                tokenize=False, 
                add_generation_prompt=True
        )

        terminators = [
            pipeline.tokenizer.eos_token_id,
            pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
        ]

        outputs = pipeline(
            prompt,
            max_new_tokens=256,
            eos_token_id=terminators,
            pad_token_id=pipeline.tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
        )
        # print(outputs[0]["generated_text"][len(prompt):])
        jslist = loads(outputs[0]["generated_text"][len(prompt):])
        if len(jslist["colnames"]) != len(colnames):
            print(jslist["colnames"], colnames)
        else:
            real_cols += colnames
            pred_cols += jslist["colnames"]

  0%|          | 0/100 [00:00<?, ?it/s]

 10%|█         | 10/100 [00:04<00:44,  2.04it/s]

['name', 'description', 'year', 'status'] ['notes']


 12%|█▏        | 12/100 [00:05<00:45,  1.92it/s]

['result', 'team', 'name'] ['name']


 13%|█▎        | 13/100 [00:06<00:51,  1.70it/s]

['FR Doc.', 'RIN', 'Docket No.'] ['type']


 16%|█▌        | 16/100 [00:07<00:45,  1.86it/s]

['type', 'zip code'] ['county']


 26%|██▌       | 26/100 [00:12<00:39,  1.89it/s]

['Docket No.', 'RIN', 'FR Doc.'] ['type']


 31%|███       | 31/100 [00:14<00:32,  2.11it/s]

['city', 'country'] ['location']


 35%|███▌      | 35/100 [00:16<00:28,  2.29it/s]

['location', 'distance'] ['city']


 38%|███▊      | 38/100 [00:17<00:33,  1.86it/s]

['team', 'name', 'result', 'category'] ['class']


 41%|████      | 41/100 [00:19<00:35,  1.66it/s]

['CC', 'REMI', 'LAPS', 'FP', 'Effective date'] ['code', 'description']


 44%|████▍     | 44/100 [00:22<00:44,  1.25it/s]

['description', 'option', 'behavior', 'log', 'action', 'file', 'destination', 'update'] ['description']


 47%|████▋     | 47/100 [00:23<00:31,  1.67it/s]

['name', 'description'] ['position']


 49%|████▉     | 49/100 [00:24<00:26,  1.95it/s]

['name', 'description'] ['name']


 51%|█████     | 51/100 [00:25<00:27,  1.80it/s]

['category', 'description', 'company', 'industry'] ['industry']


 53%|█████▎    | 53/100 [00:26<00:26,  1.78it/s]

['country', 'region', 'of', 'manufacture'] ['language']


 55%|█████▌    | 55/100 [00:27<00:26,  1.73it/s]

['1st', '2nd', '3rd'] ['person']


 62%|██████▏   | 62/100 [00:33<00:48,  1.27s/it]

['type', 'category', 'team', 'status', 'rank', 'year', 'location', 'teamName', 'city', 'state', 'country', 'industry', 'genre', 'company', 'product', 'service', 'album', 'isbn', 'publisher', 'origin', 'birthDate', 'birthPlace', 'education', 'director', 'sales', 'continent', 'organisation'] ['description']


 68%|██████▊   | 68/100 [00:36<00:17,  1.81it/s]

['description', 'notes'] ['description']


 91%|█████████ | 91/100 [00:46<00:04,  2.17it/s]

['city', 'year'] ['team']


100%|██████████| 100/100 [00:50<00:00,  1.97it/s]

['RIN', 'FR Doc.'] ['type']





In [35]:
with open(f'{tabledir}_true.npy', 'wb') as f:
    np.save(f, np.array(real_cols, dtype='<U14'))
with open(f'{tabledir}_pred.npy', 'wb') as f:
    np.save(f, np.array(pred_cols, dtype='<U14'))
"""correct = 0
total = 0
for real,pred in zip(real_cols, pred_cols):
    for r,p in zip(real,pred):
        total += 1
        if r == (p[0].lower() + p[1:]):
            correct += 1
print(f'Accuracy: {correct/total}')"""

"correct = 0\ntotal = 0\nfor real,pred in zip(real_cols, pred_cols):\n    for r,p in zip(real,pred):\n        total += 1\n        if r == (p[0].lower() + p[1:]):\n            correct += 1\nprint(f'Accuracy: {correct/total}')"

In [36]:
print(len(pred_cols))
print(len(real_cols))
print(real_cols[0], pred_cols[0])

127
127
name name
