In [None]:
import json
import torch
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
from tqdm import tqdm
import glob

# Sentence-final punctuation: full-width full stop, exclamation and question marks.
# The capturing group makes re.split keep each mark as its own list element.
_SENTENCE_ENDINGS = re.compile(r'([。！？])')


def split_chinese_text(text):
    """Split Chinese *text* into sentences, keeping each sentence's end mark.

    ``re.split`` with one capturing group returns
    ``[body0, mark0, body1, mark1, ..., tail]``. Each body is re-joined with
    the mark that follows it. The trailing ``tail`` (text after the last end
    mark, i.e. an unterminated final sentence) is kept as its own sentence
    rather than being silently dropped.

    Args:
        text: The input string.

    Returns:
        A list of whitespace-stripped, non-empty sentences in original order.
    """
    parts = _SENTENCE_ENDINGS.split(text)
    # Pair every body with the end mark that follows it.
    sentences = [body + mark for body, mark in zip(parts[0::2], parts[1::2])]
    # With a single capturing group the split result always has odd length,
    # so parts[-1] is the text after the final mark (possibly "").
    sentences.append(parts[-1])
    return [s.strip() for s in sentences if s.strip()]

# Hugging Face Hub ID (or local path) of the causal LM used for scoring.
MODEL_PATH = "meta-llama/Llama-3.1-8B"

# Tokenizer and model come from the same checkpoint so token ids line up.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)

# Weights are requested in float16 to reduce memory use.
# NOTE(review): no device is specified here, so the model stays on the library's
# default device — confirm float16 inference is supported/fast enough there.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
)

# Function to compute perplexity for a single text
def compute_ppl(text):
    """Return the perplexity of *text* under the module-level ``model``.

    The text is tokenized with the module-level ``tokenizer``. It is then scored
    by passing ``labels=input_ids``, and the result is ``exp`` of the loss the
    model returns.

    Args:
        text: The string to score.

    Returns:
        The perplexity as a Python float.
    """
    # Put the inputs on whatever device holds the model's weights; otherwise
    # the forward pass fails as soon as the model is moved off the CPU.
    input_ids = tokenizer(text, return_tensors="pt").input_ids.to(model.device)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    # Exponentiate in float32: if the loss comes back as float16, exp()
    # overflows to inf once the loss exceeds ~11.09 (= ln 65504).
    loss = outputs.loss.float()
    # NOTE(review): a text that tokenizes to a single token (e.g. an empty
    # string plus BOS) presumably leaves no next-token targets and yields a
    # nan loss — confirm; callers averaging PPLs should guard against it.
    return torch.exp(loss).item()

# Read a JSON evaluation file from disk
def load_ppl_eval_data(file_path):
    """Parse and return the JSON document stored at *file_path*.

    The file is opened as UTF-8 text and closed before returning.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        Whatever object the JSON document decodes to.
    """
    with open(file_path, encoding="utf-8") as handle:
        return json.load(handle)

# Compute and print average PPL for each file
def compute_and_print_avg_ppl(input_directory, text_key="predict"):
    """Print (and return) the mean per-entry perplexity of each JSON file.

    Every ``*.json`` file directly inside *input_directory* is loaded with
    ``load_ppl_eval_data``. Each entry's ``entry[text_key]`` is scored with
    ``compute_ppl``, and the arithmetic mean of the per-entry PPLs is printed.

    Args:
        input_directory: Directory searched (non-recursively) for ``*.json``.
        text_key: Field of each entry holding the text to score. Defaults to
            ``"predict"``, the field this script has always read.

    Returns:
        A dict mapping each processed file path to its average PPL. The value
        is ``nan`` when a file yields no finite PPL values.

    Raises:
        KeyError: If an entry lacks ``text_key``.
    """
    import math

    results = {}
    # Escape the directory so characters like "[" are not treated as glob
    # syntax, and sort so files are processed in a stable order.
    pattern = os.path.join(glob.escape(input_directory), "*.json")
    for file_path in sorted(glob.glob(pattern)):
        print(f"Processing file: {file_path}")
        data = load_ppl_eval_data(file_path)
        ppls = [compute_ppl(entry[text_key]) for entry in tqdm(data)]
        # One nan/inf (e.g. from a text too short to score) would make the
        # whole average nan/inf, so drop such values and say how many.
        finite_ppls = [p for p in ppls if math.isfinite(p)]
        skipped = len(ppls) - len(finite_ppls)
        if skipped:
            print(f"Skipped {skipped} non-finite PPL value(s) in {file_path}")
        avg_ppl = sum(finite_ppls) / len(finite_ppls) if finite_ppls else float('nan')
        print(f"Average PPL for {file_path}: {avg_ppl}")
        results[file_path] = avg_ppl
    return results

# Example usage: run the evaluation only when executed directly (script or
# notebook), not when this module is imported for its helpers.
if __name__ == "__main__":
    input_directory = "./llama_results"  # Replace with your input directory
    compute_and_print_avg_ppl(input_directory)
