In [5]:
import json
# Location of the JSON results file, relative to the notebook's working directory.
# Later cells iterate `dataset.items()` and read the 'query', 'pred' and 'label'
# fields of each record.
data_path = "../../result/multitask_document_related.json"
# Decode explicitly as UTF-8 so loading does not depend on the platform's
# default locale encoding.
with open(data_path, "r", encoding="utf-8") as f:
    dataset = json.load(f)

In [2]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

# Checkpoint directory used for this run. The other checkpoints that were tried
# stay commented out so they can be swapped in quickly.
# model_path = "/home/jovyan/hdfs-jmt-rungjoo-private/huggingface_models/JoSw-14/LoKuS-13B"
# model_path = "/home/jovyan/hdfs-jmt-rungjoo-private/huggingface_models/uukuguy/speechless-llama2-luban-orca-platypus-13b"
model_path = "/home/jovyan/hdfs-jmt-rungjoo-private/huggingface_models/AIDC-ai-business/Luban-13B"

# Tokenizer and weights are read from the same directory. The weights are
# requested as float16 and device placement is left to device_map="auto".
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.float16)

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 3/3 [01:13<00:00, 24.38s/it]


In [4]:
# Persist the in-memory `model` and `tokenizer` to a local directory.
# NOTE(review): as the notebook currently reads, `model` is loaded from the
# Luban-13B directory while `save_path` names the
# speechless-llama2-luban-orca-platypus-13b directory. A top-to-bottom re-run
# would therefore store one checkpoint under another checkpoint's name —
# confirm which checkpoint was actually in memory when this cell was executed.
save_path = "/home/jovyan/hdfs-jmt-rungjoo-private/huggingface_models/uukuguy/speechless-llama2-luban-orca-platypus-13b"
model.save_pretrained(save_path)
# Last expression of the cell: its return value (a tuple of written tokenizer
# file paths) is what the notebook displays as the cell output.
tokenizer.save_pretrained(save_path)

('/home/jovyan/hdfs-jmt-rungjoo-private/huggingface_models/uukuguy/speechless-llama2-luban-orca-platypus-13b/tokenizer_config.json',
 '/home/jovyan/hdfs-jmt-rungjoo-private/huggingface_models/uukuguy/speechless-llama2-luban-orca-platypus-13b/special_tokens_map.json',
 '/home/jovyan/hdfs-jmt-rungjoo-private/huggingface_models/uukuguy/speechless-llama2-luban-orca-platypus-13b/tokenizer.json')

In [3]:
def make_prompt(query, pred_facets, method):
    """Assemble the two-shot facet prompt that the model is asked to complete.

    Args:
        query: Search query whose facets are wanted; interpolated into the prompt.
        pred_facets: Already-joined string of predicted facets. Only used when
            ``method == "post"``.
        method: ``"post"`` builds a correction prompt (predicted facets shown next
            to the correct ones, followed by an instruction to fix the prediction
            for ``query``). Any other value builds the "unseen" prompt, which only
            shows the two demonstrations.

    Returns:
        The prompt string. Both variants end with the unfinished assistant
        sentence "The correct facets for '<query>' are".
    """
    # Shared tail: the assistant turn is left open for the model to continue.
    answer_stub = f"""### Assistant:\nThe correct facets for '{query}' are"""

    if method != "post":
        # "unseen" variant: no prediction to correct, just two demonstrations.
        pieces = [
            """### User:\nThe facets for 'caesars atlantic city' are 'caesars atlantic city events, caesars atlantic city jobs, caesars atlantic city parking'\n""",
            """The facets for 'vista, ca' are 'weather, zip code, population, homes for sale'\n\n""",
            answer_stub,
        ]
        return "".join(pieces)

    # "post" variant: demonstrations pair a wrong prediction with the correct
    # facets, then the same question is asked about `query`.
    pieces = [
        """### User:\nThe predicted facets for 'caesars atlantic city' are 'parking, hotels'. But the correct facets are 'caesars atlantic city events, caesars atlantic city jobs, caesars atlantic city parking'\n""",
        """The predicted facets for 'vista, ca' are 'parking, hotels'. But the correct facets are 'weather, zip code, population, homes for sale'\n\n""",
        f"""As in the example above, modify the predicted facets.\nThe predicted facets for '{query}' are '{pred_facets}'. What are the correct facets?\n\n""",
        answer_stub,
    ]
    return "".join(pieces)

In [18]:
import re
import time

# Matches a single-quoted run made only of ASCII letters, whitespace and commas.
# A quoted answer containing digits, hyphens or other characters does not match.
eng_rule = re.compile(r"\'[a-zA-Z\s,]+\'")

# Debug run: generate a corrected facet list for the first record whose index is
# greater than 8, print it next to the gold label, then stop.
for ind, data in dataset.items():
    if int(ind) <= 8:
        continue

    # perf_counter is the monotonic clock intended for measuring intervals.
    st = time.perf_counter()
    query = data['query']
    pred_facet_list = data['pred']
    pred_facets = ", ".join(pred_facet_list)

    method = "post"
    prompt = make_prompt(query, pred_facets, method)
    print(prompt)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_len = inputs['input_ids'].shape[1]

    # The generation budget is twice the tokenized length of the gold label string.
    # NOTE(review): the count presumably includes any special tokens the tokenizer
    # adds (e.g. BOS) — confirm if the budget needs to be exact.
    label_str = ", ".join(data['label'])
    label_inputs = tokenizer(label_str, return_tensors="pt")
    label_len = label_inputs['input_ids'].shape[1]

    # Greedy decoding is requested explicitly with do_sample=False instead of
    # temperature=0, which is not a valid sampling temperature.
    generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=label_len * 2, do_sample=False)

    # `output` keeps the full decoded text (prompt + completion) as before.
    output = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    # Cut the completion out at the token level: decoding is not guaranteed to
    # reproduce `prompt` character for character, so slicing the decoded text by
    # len(prompt) can land at the wrong offset.
    correct_facets = tokenizer.decode(generated_ids[0][prompt_len:], skip_special_tokens=True)

    matches = eng_rule.findall(correct_facets.strip())
    if matches:
        # Take the first quoted span, drop the quotes, split on commas.
        correct_facet_list = [x.strip() for x in matches[0].strip("'").split(",") if x.strip() != ""]
    else:
        # The completion contained no quoted facet list that the pattern accepts;
        # report an empty result instead of failing with IndexError.
        correct_facet_list = []
    print(correct_facets)
    print(correct_facet_list)

    print("### Label:")
    print(data['label'])
    ed = time.perf_counter()
    print(ed-st)
    break

### User:
The predicted facets for 'caesars atlantic city' are 'parking, hotels'. But the correct facets are 'caesars atlantic city events, caesars atlantic city jobs, caesars atlantic city parking'
The predicted facets for 'vista, ca' are 'parking, hotels'. But the correct facets are 'weather, zip code, population, homes for sale'

As in the example above, modify the predicted facets.
The predicted facets for 'new caledonia' are 'population, new caledonia zip code'. What are the correct facets?

### Assistant:
The correct facets for 'new caledonia' are
 'geography, population, culture, economy, tourism, language, history'. The given predicted facets are too specific and do not encompass the diverse aspects of New Caledonia.
['geography', 'population', 'culture', 'economy', 'tourism', 'language', 'history']
### Label:
['new caledonia population', 'new caledonia flag', 'time in new caledonia', 'new caledonia news']
1.5268876552581787


In [21]:
# Scratch check of a looser pattern than `eng_rule`: any characters between
# single quotes. `.+` is greedy, so when a line holds several quoted phrases the
# match runs from the first quote to the last one on that line.
test_rule = re.compile(r"\'.+\'")
# Relies on `correct_facets` left in the kernel by the generation cell, and
# rebinds the `matches` name that cell also assigns.
matches = test_rule.findall(correct_facets.strip())
# Last expression: displayed as the cell output.
matches

["'geography, population, culture, economy, tourism, language, history'"]