In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
from datasets import load_from_disk
from functools import partial
import time
import logging
import os
from torch.utils.data import DataLoader
from datetime import timedelta
import random
random.seed(42)
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3"

In [2]:
dataset = load_from_disk("./datasets/Rog-webqsp_test_rr")
case_1 = dataset[0]
case_2 = dataset[1]
dataset = dataset.select(range(2,len(dataset)))
dataset_1hop = dataset.filter(lambda x: x["gt_path_triplet"]is not None and len(x["gt_path_triplet"])==1)
dataset_2hop = dataset.filter(lambda x: x["gt_path_triplet"]is not None and len(x["gt_path_triplet"])==2)

In [3]:
print(dataset[0]['gn_path_triplet'])
print(dataset[0]['gt_path_triplet'])

[['JaMarcus Russell', 'people.person.gender', 'Male'], ['JaMarcus Russell', 'people.person.parents', 'Bobby Lloyd'], ['JaMarcus Russell', 'people.person.nationality', 'United States of America'], ['JaMarcus Russell', 'people.person.ethnicity', 'African American'], ['JaMarcus Russell', 'common.topic.image', 'JaMarcus Russell at Falcons at Raiders 11-2-08'], ['JaMarcus Russell', 'people.person.parents', 'Zina L. Russell-Anderson'], ['JaMarcus Russell', 'people.person.places_lived', 'm.03phjs_'], ['JaMarcus Russell', 'sports.pro_athlete.sports_played_professionally', 'm.0c550qk']]
[['JaMarcus Russell', 'people.person.place_of_birth', 'Mobile']]


In [4]:
def build_messages_1hop(example,no_rag,zero_shot,n_negative,fs_case_1,fs_case_2):

    q = example["question"]
    if(no_rag):
        messages = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": f"{q}?"},
        ]
    else:
        def get_knowledge_text(data):
            triplet_list = data['gn_path_triplet'][:n_negative]+data['gt_path_triplet']
            random.shuffle(triplet_list)
            triplet_list = ["("+"->".join(t)+")" for t in triplet_list]
            txt="\n".join(triplet_list)
            return txt
        
        txt = get_knowledge_text(example)
        
        if zero_shot:
            messages = [
            {"role": "system", "content": "You are a helpful assistant that answers the user's question based on the user's knowledge."},
            {"role": "user", "content": f"""[Knowledge]
{txt}
[Question]
{q}?"""}
            ]      
        else:
            txt_1 = get_knowledge_text(fs_case_1)
            q_1 = fs_case_1["question"]
            answer_1 = fs_case_1["a_entity"][0]
            txt_2 = get_knowledge_text(fs_case_2)
            q_2 = fs_case_2["question"]
            answer_2 = fs_case_2["a_entity"][0]

            messages = [
            {"role": "system", "content": "You are a helpful assistant that answers the user's question based on the user's knowledge."},
            {"role": "user", "content": f"""[Knowledge]
{txt_1}
[Question]
{q_1}?"""},
            {"role": "assistant", "content": f"The answer is {answer_1}"},
            {"role": "user", "content": f"""[Knowledge]
{txt_2}
[Question]
{q_2}?"""},
            {"role": "assistant", "content": f"The answer is {answer_2}"},
            {"role": "user", "content": f"""[Knowledge]
{txt}
[Question]
{q}?"""},
            ]
    example["messages"] =messages
    return example

In [5]:

def set_file_handler(logger, path, level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s"):
    os.makedirs(os.path.dirname(path + "/run.log"), exist_ok=True)
    handler = logging.FileHandler(path + "/run.log")
    handler.setLevel(level)
    formatter = logging.Formatter(format)
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    
logger = logging.getLogger("inference")
logger.setLevel(logging.DEBUG)
set_file_handler(logger, "./logs")

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct", torch_dtype=torch.bfloat16, device_map="auto", attn_implementation="flash_attention_2")
tokenizer = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
model.eval()
tokenizer.pad_token = tokenizer.eos_token

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [6]:

from torch.utils.data import DataLoader
def build_prompt(example, tokenizer):
    prompt = tokenizer.apply_chat_template(example["messages"], tokenize=False, add_generation_prompt=True)
    prompt_tokens = len(tokenizer(prompt, add_special_tokens=False).input_ids)
    return {"prompt": prompt, "prompt_tokens": prompt_tokens}


def collate_fn(batch, tokenizer):
    prompt = [example["prompt"] for example in batch]
    inputs = tokenizer(prompt, add_special_tokens=False, padding=True, return_tensors="pt")
    return inputs


def generate(model, tokenizer, dataloader, logger, log_every, **kwargs):
    start = time.time()
    output_ids = []
    for i, inputs in tqdm(enumerate(dataloader, start=1)):
        inputs = inputs.to(model.device)
        with torch.no_grad():
            outputs = model.generate(**inputs, pad_token_id=tokenizer.eos_token_id, **kwargs)
        output_ids.extend(outputs[:, inputs["input_ids"].size(1) :].tolist())
        if i % log_every == 0:
            end = time.time()
            elapsed = end - start
            total = elapsed * (len(dataloader) / i)
            logger.info(f"Done {i}/{len(dataloader)} steps - {str(timedelta(seconds=int(elapsed)))}/{str(timedelta(seconds=int(total)))}.")
    return output_ids


def decode(example, tokenizer, feature):
    text = tokenizer.decode(example[feature + "_ids"], skip_special_tokens=True)
    return {feature: text}

dataset=dataset_1hop.map(partial(build_messages_1hop,no_rag=False,zero_shot=False,n_negative=2,fs_case_1=case_1,fs_case_2=case_2))
dataset = dataset.map(partial(build_prompt, tokenizer=tokenizer), num_proc=16)
dataloader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=16, collate_fn=partial(collate_fn, tokenizer=tokenizer), pin_memory=True)  # type: ignore
print(len(dataset),len(dataloader))

862 27


In [7]:

output_ids = generate(model, tokenizer, dataloader, logger, 10, max_new_tokens=20, do_sample=True, temperature=0.5, top_k=50, top_p=0.5)

dataset = dataset.add_column("model_answer_ids", output_ids)  # type: ignore
dataset = dataset.map(partial(decode, tokenizer=tokenizer, feature="model_answer"), num_proc=16)

27it [02:02,  4.53s/it]


Map (num_proc=16):   0%|          | 0/862 [00:00<?, ? examples/s]

In [9]:
hit=0
fail=0
for row in dataset:
    if row["a_entity"][0] in row["model_answer"]:
        hit+=1
    else:
        fail+=1
print(hit,fail)

410 452


In [10]:
print(dataset[0]["messages"])

[{'content': "You are a helpful assistant that answers the user's question based on the user's knowledge.", 'role': 'system'}, {'content': '[Knowledge]\n(Jamaica->book.book_subject.works->Culture and Customs of Jamaica)\n(Jamaica->location.country.languages_spoken->Jamaican English)\n(Jamaica->location.country.official_language->Jamaican English)\n[Question]\nwhat does jamaican people speak?', 'role': 'user'}, {'content': "The answer is ['Jamaican English', 'Jamaican Creole English Language']", 'role': 'assistant'}, {'content': '[Knowledge]\n(m.04j60kc->government.government_position_held.office_position_or_title->United States Representative)\n(James K. Polk->common.topic.image->James K. Polk)\n(James K. Polk->government.politician.government_positions_held->m.04j60kc)\n(James K. Polk->common.topic.image->James K. Polk)\n[Question]\nwhat did james k polk do before he was president?', 'role': 'user'}, {'content': "The answer is ['United States Representative', 'Governor of Tennessee', 

In [8]:

# model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# pipeline = transformers.pipeline(
#     "text-generation",
#     model=model_id,
#     model_kwargs={"torch_dtype": torch.bfloat16},
#     device_map="auto",
# )
# messages = [
#     {"role": "system", "content": "You are a helpful assistant"},
#     {"role": "user", "content": "Who are you?"},
# ]

# outputs = pipeline(
#     messages,
#     max_new_tokens=256,
# )
# print(outputs[0]["generated_text"][-1])