In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

from metrics import rouge_n, meteor, ter, chrf, wer, levenshtein_distance

In [None]:
def load_texts(predictions_path, references_path):
    """Read a predictions file and a references file, one entry per line.

    Both files are opened as UTF-8 text. Every line is stripped of leading
    and trailing whitespace (including the newline); blank lines are kept
    as empty strings, so line positions are preserved.

    Args:
        predictions_path: Path of the text file with model outputs.
        references_path: Path of the text file with reference texts.

    Returns:
        A ``(predictions, references)`` tuple of lists of strings. The two
        lists are not checked for equal length.
    """
    def _read_stripped_lines(path):
        # One list item per physical line of the file.
        with open(path, 'r', encoding='utf-8') as handle:
            return list(map(str.strip, handle))

    predictions = _read_stripped_lines(predictions_path)
    references = _read_stripped_lines(references_path)
    return predictions, references

def calculate_metrics(prediction, reference):
    """Score a single prediction against a single reference.

    The scoring helpers (``rouge_n``, ``meteor``, ``ter``, ``chrf``, ``wer``,
    ``levenshtein_distance``) are imported from the project's ``metrics``
    module; their exact definitions are not visible in this notebook.

    Args:
        prediction: Predicted text.
        reference: Reference text. It is wrapped in a one-element list for
            the helpers that are called with a list of references.

    Returns:
        dict with the keys ``rouge1``, ``rouge2`` (the ``'f1'`` entry of the
        ``rouge_n`` result for n=1 and n=2), ``meteor``, ``ter``, ``chrf``,
        ``wer`` and ``levenshtein`` (edit distance divided by the length of
        the longer string, or 0 when both strings are empty).
    """
    scores = {
        'rouge1': rouge_n(prediction, [reference], n=1)['f1'],
        'rouge2': rouge_n(prediction, [reference], n=2)['f1'],
        'meteor': meteor(prediction, [reference]),
        'ter': ter(prediction, [reference]),
        'chrf': chrf(prediction, [reference]),
        'wer': wer(prediction, reference),
    }

    # Normalise the edit distance by the longer input; two empty strings
    # would divide by zero, so they are scored as 0 without calling the helper.
    longest = max(len(prediction), len(reference))
    if longest > 0:
        scores['levenshtein'] = levenshtein_distance(prediction, reference) / longest
    else:
        scores['levenshtein'] = 0
    return scores

def aggregate_metrics(predictions, references):
    """Compute per-sample metrics for paired texts and their averages.

    Predictions and references are paired positionally with ``zip``, so
    pairing stops at the end of the shorter input.

    Args:
        predictions: Iterable of predicted texts.
        references: Iterable of reference texts, aligned with ``predictions``.

    Returns:
        A ``(average_metrics, aggregated)`` tuple:
        ``average_metrics`` maps each metric name to the mean of its
        per-sample values multiplied by 100; ``aggregated`` is a
        ``defaultdict(list)`` mapping each metric name to the list of
        per-sample values in input order.
    """
    per_sample = defaultdict(list)
    for hypothesis, gold in zip(predictions, references):
        for name, score in calculate_metrics(hypothesis, gold).items():
            per_sample[name].append(score)

    averages = {}
    for name, scores in per_sample.items():
        # Scale the mean by 100 so it can be reported as a percentage.
        averages[name] = np.mean(scores) * 100
    return averages, per_sample

def plot_comparative_average_metrics(mbart_metrics, indictrans_metrics, output_dir):
    """Save a grouped horizontal bar chart comparing two models' average metrics.

    One row is drawn per metric, with an mBART bar and an IndicTrans bar
    side by side. The figure is written to
    ``<output_dir>/comparative_average_metrics.png`` and then closed.

    Args:
        mbart_metrics: Mapping of metric name -> average value for mBART.
            Its key order defines the row order of the chart.
        indictrans_metrics: Mapping of metric name -> average value for
            IndicTrans. It must contain every key of ``mbart_metrics``;
            keys that only exist here are ignored.
        output_dir: Directory the PNG is written to.

    Raises:
        KeyError: If ``indictrans_metrics`` lacks a metric that is present
            in ``mbart_metrics``.
    """
    metrics = list(mbart_metrics.keys())
    mbart_scores = [mbart_metrics[name] for name in metrics]
    # Look each score up by metric name rather than relying on both dicts
    # having the same insertion order; positional pairing would silently
    # attach IndicTrans values to the wrong metric label.
    indictrans_scores = [indictrans_metrics[name] for name in metrics]

    bar_width = 0.35
    index = np.arange(len(metrics))

    plt.figure(figsize=(12, 7))
    plt.barh(index, mbart_scores, bar_width, label='mBART', color='skyblue')
    # Offset the second series by one bar width so the pair sits side by side.
    plt.barh(index + bar_width, indictrans_scores, bar_width, label='IndicTrans', color='salmon')

    plt.xlabel('Score (%)')
    # Centre each tick label between the two bars of its metric.
    plt.yticks(index + bar_width / 2, metrics)
    plt.legend()
    plt.title('Comparative Average Metrics')
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'comparative_average_metrics.png'))
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close()
    print("Saved comparative average metrics bar plot.")

def plot_comparative_metric_distributions(mbart_aggregated, indictrans_aggregated, output_dir, bins=20):
    """Save one overlaid histogram per metric comparing both models' per-sample values.

    For every metric in ``mbart_aggregated`` a figure with two
    semi-transparent histograms (mBART and IndicTrans) is written to
    ``<output_dir>/comparative_<metric>_distribution.png`` and then closed.

    Args:
        mbart_aggregated: Mapping of metric name -> sequence of per-sample
            values for mBART. Its keys decide which figures are produced.
        indictrans_aggregated: Mapping of metric name -> sequence of
            per-sample values for IndicTrans, indexed with the same keys.
        output_dir: Directory the PNG files are written to.
        bins: Bin specification forwarded to ``numpy.histogram_bin_edges``
            (an int number of bins by default).
    """
    for metric, mbart_values in mbart_aggregated.items():
        indictrans_values = indictrans_aggregated[metric]

        # Derive a single set of bin edges from the pooled values. Letting
        # each hist() call pick its own range gives the two series bins of
        # different width and position, which makes the overlay misleading.
        pooled = np.concatenate([
            np.asarray(mbart_values, dtype=float).ravel(),
            np.asarray(indictrans_values, dtype=float).ravel(),
        ])
        bin_edges = np.histogram_bin_edges(pooled, bins=bins)

        plt.figure(figsize=(10, 5))
        plt.hist(mbart_values, bins=bin_edges, alpha=0.5, label='mBART', color='skyblue', edgecolor='black')
        plt.hist(indictrans_values, bins=bin_edges, alpha=0.5, label='IndicTrans', color='salmon', edgecolor='black')

        plt.xlabel('Score')
        plt.ylabel('Frequency')
        plt.title(f'{metric.capitalize()} Distribution Comparison')
        plt.legend(loc='upper right')
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f'comparative_{metric}_distribution.png'))
        # Close each figure inside the loop so open figures do not pile up.
        plt.close()
        print(f"Saved comparative histogram for {metric}.")

In [None]:
# Checkpoint directories that hold each model's prediction/reference dumps.
mbart_dir = "/shared/3/projects/national-culture/cache/independent/cache/mbart/checkpoint-3750"
indictrans_dir = "/shared/3/projects/national-culture/cache/independent/cache/indictrans/checkpoint-1875"

mbart_predictions_path = os.path.join(mbart_dir, 'mbart_predictions.txt')
mbart_references_path = os.path.join(mbart_dir, 'mbart_references.txt')
indictrans_predictions_path = os.path.join(indictrans_dir, 'indictrans_predictions.txt')
indictrans_references_path = os.path.join(indictrans_dir, 'indictrans_references.txt')

print("Loading predictions and references for mBART...")
mbart_predictions, mbart_references = load_texts(mbart_predictions_path, mbart_references_path)
print(f"Loaded {len(mbart_predictions)} predictions for mBART.")
# Predictions and references are paired line by line further down; a length
# mismatch would be silently truncated by zip() and yield misleading scores.
if len(mbart_predictions) != len(mbart_references):
    raise ValueError(
        f"mBART: {len(mbart_predictions)} predictions but {len(mbart_references)} "
        "references; the two files must be line-aligned."
    )

print("Loading predictions and references for IndicTrans...")
indictrans_predictions, indictrans_references = load_texts(indictrans_predictions_path, indictrans_references_path)
print(f"Loaded {len(indictrans_predictions)} predictions for IndicTrans.")
if len(indictrans_predictions) != len(indictrans_references):
    raise ValueError(
        f"IndicTrans: {len(indictrans_predictions)} predictions but {len(indictrans_references)} "
        "references; the two files must be line-aligned."
    )

print("Calculating metrics for mBART...")
mbart_average_metrics, mbart_aggregated_metrics = aggregate_metrics(mbart_predictions, mbart_references)

print("Calculating metrics for IndicTrans...")
indictrans_average_metrics, indictrans_aggregated_metrics = aggregate_metrics(indictrans_predictions, indictrans_references)

print("\nAverage Metrics for mBART:")
for metric, avg_score in mbart_average_metrics.items():
    print(f"{metric.capitalize()}: {avg_score:.2f}%")

print("\nAverage Metrics for IndicTrans:")
for metric, avg_score in indictrans_average_metrics.items():
    print(f"{metric.capitalize()}: {avg_score:.2f}%")

# Persist the averages next to each checkpoint. The encoding is explicit so
# the output does not depend on the platform default, matching load_texts.
with open(os.path.join(mbart_dir, 'mbart_average_metrics.txt'), 'w', encoding='utf-8') as f:
    for metric, avg_score in mbart_average_metrics.items():
        f.write(f"{metric.capitalize()}: {avg_score:.2f}%\n")

with open(os.path.join(indictrans_dir, 'indictrans_average_metrics.txt'), 'w', encoding='utf-8') as f:
    for metric, avg_score in indictrans_average_metrics.items():
        f.write(f"{metric.capitalize()}: {avg_score:.2f}%\n")

print("Saved average metrics to files for both models.")

output_dir = "./"  # Save plots in the current directory
plot_comparative_average_metrics(mbart_average_metrics, indictrans_average_metrics, output_dir)

plot_comparative_metric_distributions(mbart_aggregated_metrics, indictrans_aggregated_metrics, output_dir)