# Тестирование перплексии как отборщика плохих переводов
https://github.com/psytechlab/empathy_dataset_transfer/issues/24

In [None]:
!nvidia-smi

In [2]:
import sys
import os.path as osp
import os

# Extend the import search path with the working directory, its parent and
# the "notebooks" / "data" sub-folders of the working directory.
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, '..'))

# Paths are built with os.path.join instead of '+' so the separator is handled
# by the OS layer; entries already present are skipped so that re-running the
# cell does not keep appending duplicates to sys.path.
for _extra_path in (
    parent_dir,
    os.path.join(current_dir, 'notebooks'),
    current_dir,
    os.path.join(current_dir, 'data'),
):
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

In [None]:
!pip install transformers evaluate tqdm -q

In [4]:
import json
from pathlib import Path
import torch
from evaluate import load
from tqdm import tqdm
import matplotlib.pyplot as plt

2025-06-29 14:13:51.950539: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751206432.400509      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751206432.526887      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [5]:
def read_json(path):
    """Load and return the JSON document stored at *path*.

    The file is decoded as UTF-8 (the encoding ``save_json`` writes with)
    rather than the platform default, and the handle is closed as soon as
    parsing finishes.
    """
    with Path(path).open(encoding="utf-8") as f:
        return json.load(f)

def save_json(obj, path):
    """Write *obj* to *path* as UTF-8 JSON with 4-space indentation.

    Non-ASCII characters are written as-is (``ensure_ascii=False``). The file
    is opened in a ``with`` block so it is flushed and closed before the
    function returns.
    """
    with Path(path).open("w", encoding="utf-8") as f:
        json.dump(obj, f, indent=4, ensure_ascii=False)

def read_file(path: str) -> str:
    """Return the whole content of the text file at *path*, decoded as UTF-8.

    ``Path.read_text`` opens and closes the file itself, so no handle is
    left open.
    """
    return Path(path).read_text(encoding="utf-8")

In [6]:
# Load the translation records to score (an ESConv file, judging by the path).
# Each record must provide the "id", "text" and "text_rus" keys, because
# process_data indexes every item with exactly those.
data_to_check = read_json("/kaggle/input/esconv/esconv_translations_all_normalize.json")

In [7]:
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Short label -> model id handed to the perplexity metric. Only the
# uncommented entries are evaluated; the others are kept as alternatives.
MODEL_NAMES = {
    # "aeonium": "aeonium/Aeonium-v1-Base-1B",
    "rugpt3small": "ai-forever/rugpt3small_based_on_gpt2",
    # "rugpt3medium": "ai-forever/rugpt3medium_based_on_gpt2",
    # "rugpt3large": "ai-forever/rugpt3large_based_on_gpt2"
}

# Bare expression: echoes the selected device in the notebook cell output.
DEVICE

'cuda'

In [8]:
def process_data(data, model_id):
    """Score every translated text in *data* with the perplexity metric.

    Args:
        data: Sequence of records; each one is indexed with the keys
            ``"id"``, ``"text"`` and ``"text_rus"``.
        model_id: Model identifier forwarded to ``perplexity.compute``.

    Returns:
        A list with one dict per input record, in input order, holding the
        record's ``id``, ``text``, ``text_rus`` and its ``perplexity``.
        Records belonging to a batch whose computation raised get ``None``
        as perplexity, so one bad batch does not abort the whole run.
    """
    print(f"\nProcessing with {model_id}")
    metric = load("perplexity", module_type="metric")

    translations = [record["text_rus"] for record in data]
    record_ids = [record["id"] for record in data]
    sources = [record["text"] for record in data]

    # Larger batches on the GPU, smaller ones on the CPU.
    batch_size = 32 if DEVICE == "cuda" else 8
    results = []

    for start in tqdm(range(0, len(translations), batch_size), desc="Calculating perplexity"):
        stop = start + batch_size
        batch_texts = translations[start:stop]

        # NOTE(review): the metric is invoked once per batch. If its compute()
        # loads the model on every call (the Hugging Face implementation
        # appears to), that cost is paid per batch -- confirm before scaling up.
        try:
            scores = metric.compute(
                predictions=batch_texts,
                model_id=model_id,
                batch_size=batch_size,
                device=DEVICE,
            )["perplexities"]
        except Exception as e:
            # Best effort: report the failure and keep the records, unscored.
            print(f"Error processing batch {start//batch_size}: {e!s}")
            scores = [None] * len(batch_texts)

        results.extend(
            {
                "id": record_id,
                "text": source,
                "text_rus": translation,
                "perplexity": score,
            }
            for record_id, source, translation, score in zip(
                record_ids[start:stop], sources[start:stop], batch_texts, scores
            )
        )

    return results

In [None]:
def _write_perplexity_report(path, items):
    """Write one ID / perplexity / original / translation entry per item to *path*."""
    with open(path, "w", encoding="utf-8") as f:
        for item in items:
            f.write(f"ID: {item['id']}\n")
            f.write(f"Perplexity: {item['perplexity']:.2f}\n")
            f.write(f"Original: {item['text']}\n")
            f.write(f"Translation: {item['text_rus']}\n")
            f.write("\n" + "="*80 + "\n")


def visualize_results(results, model_id, top_n=50):
    """Save a perplexity boxplot and the lowest/highest-perplexity examples.

    Records whose ``"perplexity"`` is ``None`` are ignored. Three files are
    written to the working directory, each named after *model_id* with
    ``/`` replaced by ``_``:

    * ``perplexity_boxplot_<model>.png`` -- boxplot of all valid scores;
    * ``top_low_perplexity_<model>.txt`` -- the *top_n* lowest-scoring
      records, in ascending order;
    * ``top_high_perplexity_<model>.txt`` -- the *top_n* highest-scoring
      records, in descending order.

    Args:
        results: List of dicts with ``"id"``, ``"text"``, ``"text_rus"`` and
            ``"perplexity"`` keys, as returned by ``process_data``.
        model_id: Model identifier; used in the plot title and file names.
        top_n: Non-negative number of records per text report (default 50).
    """
    # Filter once and derive both the plotted values and the ranked records.
    valid_results = [item for item in results if item["perplexity"] is not None]
    perplexities = [item["perplexity"] for item in valid_results]
    model_name = model_id.replace('/', '_')

    plt.figure(figsize=(10, 6))
    plt.boxplot(perplexities)
    plt.title(f"Perplexity distribution ({model_id})")
    plt.ylabel("Perplexity")
    plt.savefig(f"perplexity_boxplot_{model_name}.png")
    plt.close()

    sorted_results = sorted(valid_results, key=lambda x: x["perplexity"])
    top_low = sorted_results[:top_n]
    # Reverse first, then cut: gives the same records as [-top_n:][::-1] but
    # also yields an empty list for top_n == 0 (a bare [-0:] is the whole list).
    top_high = sorted_results[::-1][:top_n]

    _write_perplexity_report(f"top_low_perplexity_{model_name}.txt", top_low)
    _write_perplexity_report(f"top_high_perplexity_{model_name}.txt", top_high)

In [None]:
# Score the dataset with every configured model, dump the per-record results
# to JSON and produce the plot / text reports. `results` keeps the output of
# the last model after the loop and is read by the cells below.
for model_name, model_id in MODEL_NAMES.items():
    results = process_data(data_to_check, model_id)

    serialized = json.dumps(results, ensure_ascii=False, indent=2)
    Path(f"full_results_{model_name}.json").write_text(serialized, encoding="utf-8")

    visualize_results(results, model_id)

In [11]:
# Perplexity scores from `results`, without the records that failed (None).
perplexities = [
    score
    for score in (entry["perplexity"] for entry in results)
    if score is not None
]

In [None]:
import numpy as np
from scipy import stats

# Summary statistics of the collected perplexity scores.
perp_array = np.asarray(perplexities)

mean_perp = perp_array.mean()
median_perp = np.median(perp_array)
std_perp = perp_array.std()
min_perp = perp_array.min()
max_perp = perp_array.max()
# One call returns both quartiles, in the order requested.
percentile_25, percentile_75 = np.percentile(perp_array, [25, 75])

print(
    f"Mean perplexity: {mean_perp:.2f}",
    f"Median: {median_perp:.2f}",
    f"STD perplexity: {std_perp:.2f}",
    f"Min: {min_perp:.2f}, Максимум: {max_perp:.2f}",
    f"25%: {percentile_25:.2f}, 75-й процентиль: {percentile_75:.2f}",
    sep="\n",
)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Histogram of the perplexity scores with a KDE curve on top.
hist_fig, hist_ax = plt.subplots(figsize=(8, 5))
sns.histplot(perplexities, bins=20, kde=True, ax=hist_ax)
hist_ax.set_title("Perplexities distribution")
hist_ax.set_xlabel("Perplexity")
hist_ax.set_ylabel("Frequency")
plt.show()

# Boxplot of the same scores.
box_fig, box_ax = plt.subplots(figsize=(6, 4))
sns.boxplot(perplexities, ax=box_ax)
box_ax.set_title("Boxplot")
plt.show()