In [None]:
import sys
import os

# Make the parent of the current working directory importable so that the
# `src.*` package used by later cells resolves (assumes the notebook is run
# from a sub-folder of the project root — TODO confirm).
# The membership check keeps the cell idempotent: re-running it does not
# pile up duplicate entries on sys.path.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
from pathlib import Path
from src.utils.config_loader import load_config

# Directory one level above the notebook's working directory; every data,
# config and result path below is resolved relative to it.
base_dir = Path.cwd().parent

# Load the settings stored in secrets.yaml next to the project sources.
config = load_config(base_dir.joinpath('secrets.yaml'))

In [None]:
import torch
import transformers

class LlamaForSummarization:
    """Zero-shot change summarizer built on a ``text-generation`` pipeline.

    Wraps a Hugging Face ``transformers`` pipeline and exposes
    :meth:`summarize`, which prompts the model with two versions of a text
    and returns the generated description of what changed between them.
    """

    def __init__(self, model_path: str):
        """Load the model and configure padding / stop tokens.

        Args:
            model_path: Model identifier or path handed to
                ``transformers.pipeline`` as ``model``.
        """
        self.model_id = model_path
        self.pipeline = transformers.pipeline(
            'text-generation',
            model=self.model_id,
            model_kwargs={'torch_dtype': torch.float16}  # FP16 for faster inference
        )
        tokenizer = self.pipeline.tokenizer
        # Use the end-of-sequence token as the padding token.
        tokenizer.pad_token_id = tokenizer.eos_token_id

        # Token id at which generation stops (also reused as pad id in summarize).
        self.terminators = tokenizer.eos_token_id

    def format_prompt(self, v1, v2, ref):
        """Build the plain-text prompt for one pair of versions.

        Args:
            v1: First (older) version of the text.
            v2: Second (newer) version of the text.
            ref: Reference summary. Accepted for interface symmetry but
                deliberately not inserted into the prompt.

        Returns:
            The prompt string sent to the model.
        """
        prompt = f"""
    You are a helpful assistant. Your task is to summarize the changes between two versions of a text.

    The first version is: {v1}
    
    The second version is: {v2}
    
    Please provide a summary of the changes between the two versions.


    """
        return prompt

    def summarize(self, v1, v2, ref, max_tokens=50, temperature=0.0, top_p=0.9):
        """Generate a summary of the changes between ``v1`` and ``v2``.

        Args:
            v1: First version of the text.
            v2: Second version of the text.
            ref: Reference summary; forwarded to :meth:`format_prompt`, which
                does not use it.
            max_tokens: Maximum number of newly generated tokens.
            temperature: Sampling temperature. ``0`` (the default) or ``None``
                selects greedy decoding; any positive value enables sampling.
            top_p: Nucleus-sampling threshold; only used when sampling.

        Returns:
            The generated text with the prompt prefix removed and surrounding
            whitespace stripped.
        """
        prompt = self.format_prompt(v1, v2, ref)

        # Sampling is driven by the temperature: previously do_sample was
        # hard-coded to False, so temperature/top_p were silently ignored and
        # transformers warned about unused sampling flags on every call.
        use_sampling = temperature is not None and temperature > 0

        generation_kwargs = {
            'max_new_tokens': max_tokens,
            'eos_token_id': self.terminators,
            'pad_token_id': self.terminators,
            'do_sample': use_sampling,
        }
        if use_sampling:
            # Only meaningful (and only accepted without warnings) when sampling.
            generation_kwargs['temperature'] = temperature
            generation_kwargs['top_p'] = top_p

        outputs = self.pipeline(prompt, **generation_kwargs)

        # The pipeline output starts with the prompt itself; drop that prefix
        # so only the model's continuation is returned.
        generated_text = outputs[0]['generated_text'][len(prompt):].strip()

        return generated_text

In [None]:
import json

# Load the validation set from data/val_set.json under the project root.
# The file is read explicitly as UTF-8 so decoding does not depend on the
# platform's default locale encoding.
with open(base_dir / 'data/val_set.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

In [None]:
from tqdm import tqdm

model_path = 'meta-llama/Meta-Llama-3-8B-Instruct' 
model = LlamaForSummarization(model_path)

# One record per evaluated item: its id, the reference summary and the
# model's generated summary.
results = []

# Only the first item is processed here (smoke-test sized run).
data_sample = data[:1]

for item in tqdm(data_sample):

    version_1 = item['version_1']
    version_2 = item['version_2']
    ref_summary = item['ref_summary'] 

    # summarize() already returns the generated text as a plain string.
    model_summary = model.summarize(version_1, version_2, ref_summary)

    results.append({
            'id': item['id'],
            'ref_summary': ref_summary,
            # Store the string directly; indexing it with 'generated_text'
            # raised TypeError because str does not support string keys.
            'model_summary': model_summary,
        })

In [None]:
# Write the collected records as JSON Lines (one JSON object per line).
results_path = base_dir / 'results/Meta-Llama-3-8B-Instruct_ZEROSHOT.jsonl'

# Make sure the output directory exists so a fresh checkout does not fail here.
results_path.parent.mkdir(parents=True, exist_ok=True)

# ensure_ascii=False emits non-ASCII characters verbatim, so the file must be
# opened as UTF-8 explicitly instead of relying on the locale's default encoding.
with open(results_path, 'w', encoding='utf-8') as f:
    for result in results:
        json.dump(result, f, ensure_ascii=False)
        f.write('\n')