In [None]:
import os
import warnings
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoTokenizer

from metrics import rouge_n, meteor, ter, chrf, wer, levenshtein_distance

In [None]:
def load_predictions(predictions_path):
    """Read one prediction per line from a UTF-8 text file.

    Every line is stripped of leading and trailing whitespace (including the
    newline). A blank line therefore yields an empty-string prediction instead of
    being dropped, so line i of the file stays prediction i.

    Args:
        predictions_path: path to the predictions text file.

    Returns:
        list[str] of stripped lines, in file order.
    """
    with open(predictions_path, encoding='utf-8') as handle:
        return list(map(str.strip, handle))

def load_references_from_test_split(dataset_name="rahular/itihasa", split="test"):
    """Return the Sanskrit ("sn") text of every example in a dataset split.

    NOTE(review): a second function with this same name is defined later in this
    cell. That later definition replaces this one when the cell runs, so this
    version is never the one called. Consider removing the duplicate.

    Args:
        dataset_name: dataset identifier passed to ``load_dataset``.
        split: which split to load.

    Returns:
        list[str], one entry per example, in dataset order (empty strings kept).
    """
    test_split = load_dataset(dataset_name, split=split)
    return [row["translation"]["sn"] for row in test_split]

def calculate_metrics(prediction, reference):
    """Score a single prediction against a single reference.

    ROUGE-1/2 (F1), METEOR, TER and chrF are called with the reference wrapped in
    a one-element list; WER and the Levenshtein distance take the bare reference.
    The Levenshtein distance is divided by the larger of ``len(prediction)`` and
    ``len(reference)``, and is reported as 0 when both are empty.
    NOTE(review): that normalizer is a character count; if ``levenshtein_distance``
    operates on tokens rather than characters the ratio mixes units -- confirm
    against the ``metrics`` module.

    Returns:
        dict with keys, in order: rouge1, rouge2, meteor, ter, chrf, wer, levenshtein.
    """
    def _normalized_levenshtein():
        # Length guard avoids dividing by zero when both strings are empty.
        longest = max(len(prediction), len(reference))
        if longest > 0:
            return levenshtein_distance(prediction, reference) / longest
        return 0

    # Dict literal evaluates entries top to bottom, preserving key order and call order.
    return {
        'rouge1': rouge_n(prediction, [reference], n=1)['f1'],
        'rouge2': rouge_n(prediction, [reference], n=2)['f1'],
        'meteor': meteor(prediction, [reference]),
        'ter': ter(prediction, [reference]),
        'chrf': chrf(prediction, [reference]),
        'wer': wer(prediction, reference),
        'levenshtein': _normalized_levenshtein(),
    }

def load_references_from_test_split(dataset_name="rahular/itihasa", split="test", drop_empty=False):
    """Return the Sanskrit ("sn") text of each example in a dataset split.

    This is the definition in effect: it comes after an earlier function with the
    same name in this cell and replaces it.

    Empty references are KEPT by default. Predictions are loaded one per line and
    paired with references purely by position (``zip``). Removing even one
    reference here shifts every later reference by one slot, so all subsequent
    predictions get scored against the wrong reference. The recorded run shows
    this symptom: 11722 predictions vs 11721 references.
    ``aggregate_metrics`` already skips pairs whose reference is empty, which is
    the alignment-safe place to filter.

    Args:
        dataset_name: dataset identifier passed to ``load_dataset``.
        split: which split to load.
        drop_empty: if True, drop references that are empty or whitespace-only.
            Only use this when the predictions were filtered the same way.

    Returns:
        list[str] of references in dataset order.
    """
    dataset = load_dataset(dataset_name, split=split)
    references = [example["translation"]["sn"] for example in dataset]
    if drop_empty:
        references = [ref for ref in references if ref.strip()]
    return references

def aggregate_metrics(predictions, references):
    """Score prediction/reference pairs and average each metric.

    Pairs are matched by position, so ``predictions[i]`` must correspond to
    ``references[i]``. Pairs whose reference is empty or whitespace-only are
    skipped (with a message) without disturbing the alignment of later pairs.

    Args:
        predictions: iterable of predicted strings.
        references: iterable of reference strings, aligned with ``predictions``.

    Returns:
        (average_metrics, aggregated):
          - average_metrics: dict mapping metric name -> mean score * 100.
          - aggregated: defaultdict(list) mapping metric name -> per-pair scores
            (unscaled), in input order.

    Warns:
        UserWarning if the inputs differ in length. ``zip`` would otherwise
        silently drop the surplus items, and any offset between the two inputs
        means every later pair is misaligned.
    """
    predictions = list(predictions)
    references = list(references)
    if len(predictions) != len(references):
        warnings.warn(
            f"aggregate_metrics: got {len(predictions)} predictions but "
            f"{len(references)} references; only the first "
            f"{min(len(predictions), len(references))} positional pairs are scored "
            "and they may be misaligned.",
            stacklevel=2,
        )

    aggregated = defaultdict(list)
    # idx is 1-based so the message matches line numbers in the predictions file.
    for idx, (pred, ref) in enumerate(zip(predictions, references), start=1):
        if not ref.strip():  # nothing meaningful to score against
            print(f"Skipping empty reference at index {idx}")
            continue
        for key, value in calculate_metrics(pred, ref).items():
            aggregated[key].append(value)

    average_metrics = {key: np.mean(values) * 100 for key, values in aggregated.items()}
    return average_metrics, aggregated

def plot_average_metrics(average_metrics, output_dir):
    """Save a horizontal bar chart of averaged metrics as ``average_metrics.png``.

    Args:
        average_metrics: mapping of metric name -> average score. Values are
            expected to be percentages, since the x-axis is labelled 'Score (%)'.
        output_dir: directory the PNG is written into; it is not created here.
    """
    names = list(average_metrics.keys())
    scores = list(average_metrics.values())

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.barh(names, scores, color='skyblue')
    ax.set_xlabel('Score (%)')
    ax.set_title('Average Metrics')
    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'average_metrics.png'))
    plt.close(fig)  # release the figure so repeated runs don't accumulate open figures
    print("Saved bar plot of average metrics.")

def plot_metric_distributions(aggregated_metrics, output_dir):
    """Save one 20-bin histogram per metric as ``<metric>_distribution.png``.

    Args:
        aggregated_metrics: mapping of metric name -> list of per-example scores.
        output_dir: directory the PNGs are written into; it is not created here.
    """
    for name, scores in aggregated_metrics.items():
        fig, ax = plt.subplots(figsize=(8, 4))
        ax.hist(scores, bins=20, color='skyblue', edgecolor='black')
        ax.set_xlabel('Score')
        ax.set_ylabel('Frequency')
        ax.set_title(f'{name.capitalize()} Distribution')
        fig.tight_layout()
        fig.savefig(os.path.join(output_dir, f'{name}_distribution.png'))
        plt.close(fig)  # one figure per metric; close each to bound memory
        print(f"Saved histogram for {name}.")

In [None]:
# Configuration. Set the MODEL_DIR environment variable to evaluate a different
# checkpoint; otherwise the original hard-coded checkpoint path is used.
model_dir = os.environ.get(
    "MODEL_DIR",
    "/shared/3/projects/national-culture/cache/independent/cache/input-to-parse/checkpoint-1875",
)
predictions_path = os.path.join(model_dir, 'i2p_all_predictions.txt')
output_dir = model_dir  # summary text file and plots are written next to the checkpoint

print("Loading predictions and references...")
predictions = load_predictions(predictions_path)
references = load_references_from_test_split()
print(f"Loaded {len(predictions)} predictions and {len(references)} references.")

print("Calculating metrics...")
average_metrics, aggregated_metrics = aggregate_metrics(predictions, references)

# Format each summary line once so the console output and the saved file cannot drift apart.
summary_lines = [
    f"{metric.capitalize()}: {avg_score:.2f}%"
    for metric, avg_score in average_metrics.items()
]

print("\nAverage Metrics:")
for line in summary_lines:
    print(line)

# Explicit encoding so the file does not depend on the platform's default locale.
with open(os.path.join(output_dir, 'average_metrics.txt'), 'w', encoding='utf-8') as f:
    f.writelines(f"{line}\n" for line in summary_lines)
print("Saved average metrics to file.")

plot_average_metrics(average_metrics, output_dir)

plot_metric_distributions(aggregated_metrics, output_dir)

Loading predictions and references...
Loaded 11722 predictions and 11721 references.
Calculating metrics...

Average Metrics:
Rouge1: 8.69%
Rouge2: 1.01%
Meteor: 4.20%
Ter: 118.06%
Chrf: 25.64%
Wer: 118.06%
Levenshtein: 73.01%
Saved average metrics to file.
Saved bar plot of average metrics.
Saved histogram for rouge1.
Saved histogram for rouge2.
Saved histogram for meteor.
Saved histogram for ter.
Saved histogram for chrf.
Saved histogram for wer.
Saved histogram for levenshtein.
