In [19]:
#|default_exp 45_llama-for-conflation-quality

In [17]:
#| export
import argparse
import json

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from sugar.core import load_raw_file

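## Prompt

The instruction prompt sent to the model for every cluster (the code copy lives in `PROMPT_TEMPLATE` below). `{cluster_repr}` stands for the cluster's text, which the code appends directly after the final line of the prompt.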

You are an expert evaluator of entity conflation.

Task:
Determine whether the given entities belong to a coherent cluster (i.e., they represent the same underlying entity).

Instructions:
1. Provide your answer **only** in valid JSON format.
2. The JSON must contain exactly two keys:
   - "score": an integer from 1 (very incoherent) to 5 (highly coherent).
   - "reason": a short explanation for the score.
3. Ensure that the output can be parsed by:
   ```python
   import json
   content = json.loads(output)
   ```
where `content` has exactly two keys: "score" and "reason".

Example:
Input: Apple iPhone 14 Pro || Samsung Galaxy S23 || Google Pixel 7
Output:
{
"score": 1,
"reason": "Entities belong to the same category (smartphones) but are from different manufacturers and product lines."
}

Now evaluate the following cluster:
{cluster_repr}

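## Driver

Build one prompt per cluster, run batched greedy generation with the model, parse each JSON reply, and write the results (tagged with each cluster's identifier) to a JSON file.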

In [13]:
#| export
# Instruction prompt for the cluster-coherence judge. `make_prompts` appends each
# cluster's text directly after the last line; the string ends with a newline,
# so the cluster text starts on its own line.
# NOTE(review): this text duplicates the "## Prompt" markdown cell — keep the two in sync.
# NOTE(review): unlike the in-prompt example, the appended cluster gets no
# "Input:"/"Output:" cue, and the prompt is tokenized as raw text rather than through
# the instruct model's chat template — consider both if replies often are not bare JSON.
PROMPT_TEMPLATE = """You are an expert evaluator of entity conflation.

Task:
Determine whether the given entities belong to a coherent cluster (i.e., they represent the same underlying entity).

Instructions:
1. Provide your answer **only** in valid JSON format.
2. The JSON must contain exactly two keys:
   - "score": an integer from 1 (very incoherent) to 5 (highly coherent).
   - "reason": a short explanation for the score.
3. Ensure that the output can be parsed by:
   ```python
   import json
   content = json.loads(output)
   ```
where `content` has exactly two keys: "score" and "reason".

Example:
Input: Apple iPhone 14 Pro || Samsung Galaxy S23 || Google Pixel 7
Output:
{
"score": 1,
"reason": "Entities belong to the same category (smartphones) but are from different manufacturers and product lines."
}

Now evaluate the following cluster:
"""

In [15]:
#| export
def make_prompts(clusters):
    """Build the full model prompt for each cluster.

    Args:
        clusters: iterable of cluster texts (strings); each is appended verbatim
            after `PROMPT_TEMPLATE`.

    Returns:
        list of prompt strings, one per cluster, in input order.
    """
    prompts = []
    for cluster in clusters:
        prompts.append(PROMPT_TEMPLATE + cluster)
    return prompts
    

In [16]:
#| export
def _parse_generation(text):
    """Return the first JSON object in a model reply, or a placeholder record if none parses."""
    start = text.find("{")
    if start != -1:
        try:
            # raw_decode stops at the end of the object, so trailing chatter or a
            # closing markdown fence after the JSON does not cause a failure.
            parsed, _ = json.JSONDecoder().raw_decode(text[start:])
            if isinstance(parsed, dict):
                return parsed
        except json.JSONDecodeError:
            pass
    # Keep the raw reply so an unparseable answer is visible in the output file
    # instead of aborting the whole run.
    return {"score": None, "reason": None, "raw_output": text}


def evaluate_clusters_in_batch(clusters, batch_size=4, max_new_tokens=128):
    """Ask the model to score the coherence of each cluster.

    Uses the module-level `model` and `tokenizer` (created in the `__main__` block).
    The tokenizer must have a pad token; left padding is recommended for
    decoder-only batched generation.

    Args:
        clusters: sequence of cluster texts; see `make_prompts`.
        batch_size: number of prompts per `model.generate` call.
        max_new_tokens: generation budget per reply; too small a value can cut
            the JSON short and produce a placeholder record.

    Returns:
        list of dicts, one per cluster, in input order. Each is the first JSON
        object found in the model's reply, or
        `{"score": None, "reason": None, "raw_output": <reply text>}` when no JSON
        object could be parsed.
    """
    prompts = make_prompts(clusters)

    generations = []
    for start in range(0, len(prompts), batch_size):
        batch_prompts = prompts[start:start + batch_size]
        # Tokenize per batch so each batch is padded only to its own longest prompt.
        inputs = tokenizer(
            batch_prompts, return_tensors="pt", padding=True, truncation=True
        ).to(model.device)

        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )

        # generate() returns the (padded) prompt followed by the new tokens, so the
        # reply starts right after the padded prompt length. Slicing tokens avoids
        # relying on the decoded text reproducing the prompt string exactly.
        prompt_len = inputs["input_ids"].shape[1]
        replies = tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
        generations.extend(_parse_generation(reply.strip()) for reply in replies)

    return generations
    

In [18]:
#| export
def parse_args():
    """Parse this script's command-line options.

    Unknown arguments are ignored (`parse_known_args`), presumably so the
    function also works when called inside a notebook kernel with extra argv.

    Returns:
        argparse.Namespace with `input_file` and `output_file` (both str).
    """
    parser = argparse.ArgumentParser(
        description="Score the coherence of entity clusters with an LLM judge."
    )
    parser.add_argument('--input_file', type=str, required=True,
                        help='file of clusters to evaluate (read by load_raw_file)')
    parser.add_argument('--output_file', type=str, required=True,
                        help='path of the JSON file the scored clusters are written to')
    return parser.parse_known_args()[0]
    

In [None]:
#| export
if __name__ == '__main__':
    # Script entry point: score every cluster in --input_file and write the
    # results to --output_file as a JSON list.
    args = parse_args()

    model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Batched tokenization with padding=True needs a pad token; Llama 3 tokenizers
    # typically ship without one, so reuse EOS when it is missing.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Decoder-only models must be left-padded for batched generation so that new
    # tokens are generated right after each prompt rather than after padding.
    tokenizer.padding_side = "left"
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto"
    )

    # NOTE(review): assumes load_raw_file returns parallel sequences of cluster
    # identifiers and cluster texts — confirm against sugar.core.
    ids, txts = load_raw_file(args.input_file)
    generations = evaluate_clusters_in_batch(txts)

    # Tag each result record with the identifier of the cluster it scores.
    for identifier, record in zip(ids, generations):
        record['identifier'] = identifier

    with open(args.output_file, 'w') as file:
        json.dump(generations, file)
        