In [4]:
import csv
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

def normalize(s):
    """Return ``s`` trimmed, with each internal whitespace run reduced to one space."""
    # str.split() with no separator already discards leading/trailing
    # whitespace and treats any run of whitespace as a single delimiter.
    words = s.split()
    return " ".join(words)

def get_sentence_list(story_string):
    """Break a ' | '-delimited story string into its sentences.

    Each delimited segment is passed through ``normalize`` and the results
    are returned as a list, in the order they appear in ``story_string``.
    """
    raw_segments = story_string.split(' | ')
    return list(map(normalize, raw_segments))

def calculate_adjacent_sentence_accuracy(csv_path):
    """
    Calculates the Adjacent Sentence Accuracy (ASA) for the model outputs.

    ASA is the fraction of adjacent sentence pairs (S_i, S_i+1) in the model's
    ordering that also occur as an adjacent pair, in the same order, in the
    gold ordering. The score is pooled over all scored rows (total correct
    pairs divided by total pairs), not averaged per story.

    Args:
        csv_path: Path to a UTF-8 CSV file with a header row containing the
            columns 'gold' and 'model_reordered'. Both cells are parsed with
            ``get_sentence_list``.

    Returns:
        A 5-tuple ``(asa_score, correct_adjacent_pairs, total_adjacent_pairs,
        out_of_range_count, wrong_num_sentences_count)``.

        * Rows whose prediction contains the literal marker
          "INDEX_OUT_OF_RANGE" are counted in ``out_of_range_count`` and
          not scored.
        * Rows whose prediction has a different number of sentences than the
          gold story are counted in ``wrong_num_sentences_count`` and not
          scored.
        * Stories with one sentence or fewer have no adjacent pairs and are
          skipped without being counted anywhere.
        * ``asa_score`` is 0 when no pairs were scored.

        If ``csv_path`` does not exist, an error line is printed and the
        sentinel ``(0, -1, -1, -1, -1)`` is returned.
    """
    total_adjacent_pairs = 0
    correct_adjacent_pairs = 0
    out_of_range_count = 0
    wrong_num_sentences_count = 0

    try:
        # newline='' hands line-ending handling to the csv module so quoted
        # fields that contain embedded newlines are parsed correctly.
        with open(csv_path, encoding='utf-8', newline='') as f:
            for row in csv.DictReader(f):
                gold_sentences = get_sentence_list(row['gold'])
                pred_sentences = get_sentence_list(row['model_reordered'])

                if any("INDEX_OUT_OF_RANGE" in s for s in pred_sentences):
                    out_of_range_count += 1
                    continue
                if len(pred_sentences) != len(gold_sentences):
                    wrong_num_sentences_count += 1
                    continue
                if len(gold_sentences) <= 1:
                    continue

                # Compare against the set of gold adjacent pairs instead of
                # looking each sentence up with list.index(): index() only
                # ever finds the FIRST occurrence, so a story containing a
                # repeated sentence had correct pairs scored as wrong.
                gold_pairs = set(zip(gold_sentences, gold_sentences[1:]))
                pred_pairs = list(zip(pred_sentences, pred_sentences[1:]))

                total_adjacent_pairs += len(pred_pairs)
                correct_adjacent_pairs += sum(
                    1 for pair in pred_pairs if pair in gold_pairs
                )

    except FileNotFoundError:
        print(f"Error: CSV file not found at {csv_path}")
        return 0, -1, -1, -1, -1

    asa_score = correct_adjacent_pairs / total_adjacent_pairs if total_adjacent_pairs > 0 else 0
    return asa_score, correct_adjacent_pairs, total_adjacent_pairs, out_of_range_count, wrong_num_sentences_count



In [7]:
def visualize_all_metrics(models, asa_scores, pmr_scores, kendall_tau_scores):
    """
    Save a grouped bar chart comparing PMR, ASA and Kendall tau per model.

    ``models`` supplies the x-axis labels; the three score arguments are
    sequences aligned with it (one value per model). Within each model's
    group the bars are drawn left to right as PMR, ASA, Kendall tau. The
    y-axis is fixed to the range 0..1. The figure is written to
    'full_comp.png' relative to the current working directory, then closed,
    and a confirmation line is printed. Nothing is returned.
    """
    scores = pd.DataFrame({
        'Model': models,
        'PMR': pmr_scores,
        'ASA': asa_scores,
        'Kendall_Tau': kendall_tau_scores,
    })

    bar_width = 0.25
    group_starts = np.arange(len(scores['Model']))

    # (score column, fill colour, legend label), in left-to-right bar order.
    bar_specs = (
        ('PMR', 'skyblue', 'PMR'),
        ('ASA', 'lightgreen', 'Adjacent Sentence Accuracy (ASA)'),
        ('Kendall_Tau', 'lightcoral', 'Average Kendall $\\tau$'),
    )

    fig, ax = plt.subplots(figsize=(12, 7))
    for slot, (column, colour, legend_label) in enumerate(bar_specs):
        # Each successive metric is shifted right by one bar width.
        ax.bar(
            group_starts + slot * bar_width,
            scores[column],
            color=colour,
            width=bar_width,
            edgecolor='grey',
            label=legend_label,
        )

    ax.set_xlabel('Model', fontweight='bold')
    ax.set_ylabel('Score', fontweight='bold')
    # Ticks sit one bar width in, i.e. under the middle bar of each group.
    ax.set_xticks(group_starts + bar_width)
    ax.set_xticklabels(scores['Model'])
    ax.set_title('Model Performance Comparison: PMR, ASA, and Kendall $\\tau$', fontweight='bold')
    ax.legend(loc='lower left')
    ax.set_ylim(0, 1.0)
    ax.grid(axis='y', linestyle='--')
    fig.tight_layout()

    output_path = 'full_comp.png'
    fig.savefig(output_path)
    plt.close(fig)
    print(f"\nGenerated all-metrics comparison plot: {output_path}")


def main():
    """Compute ASA for each model's output CSV, print a summary, and plot all metrics.

    For every model the per-file ASA is printed together with the raw
    correct/total pair counts, followed by an overall ASA pooled across all
    files that could be read. A file that is missing (signalled by
    ``calculate_adjacent_sentence_accuracy`` returning negative counts) is
    reported as unavailable and excluded from the pooled totals; its ASA is
    still passed to the plot as 0.
    """
    model_csv_paths = [
        ('GPT-5 Nano', "data/final_outputs/train_gpt5_nano_reordered.csv"),
        ('LLaMA-3', "data/final_outputs/train_llama3_reordered.csv"),
        ('Qwen', "data/final_outputs/train_qwen_reordered.csv"),
    ]

    # Score every file first so any "file not found" message from the scorer
    # appears before the results table rather than in the middle of it.
    results = []
    for model_name, csv_path in model_csv_paths:
        asa, matches, total_pairs, _, _ = calculate_adjacent_sentence_accuracy(csv_path)
        results.append((model_name, asa, matches, total_pairs))

    overall_matches = 0
    overall_total_pairs = 0

    print("--- Adjacent Sentence Accuracy (ASA) Results ---")
    for model_name, asa, matches, total_pairs in results:
        heading = f"{model_name} ASA:"
        if total_pairs < 0:
            # Missing-file sentinel: adding its -1 counts would skew the
            # overall totals, so leave this model out of the aggregate.
            print(f"{heading:<16}n/a (results file not found)")
            continue
        overall_matches += matches
        overall_total_pairs += total_pairs
        print(f"{heading:<16}{asa:.4f} ({matches} / {total_pairs} pairs correct)")

    overall_asa = overall_matches / overall_total_pairs if overall_total_pairs > 0 else 0
    overall_heading = "Overall ASA:"
    print("----------------------------------------------")
    print(f"{overall_heading:<16}{overall_asa:.4f} ({overall_matches} / {overall_total_pairs} pairs correct)")

    visualize_all_metrics(
        models=[model_name for model_name, _, _, _ in results],
        asa_scores=[asa for _, asa, _, _ in results],
        # NOTE(review): PMR and Kendall tau are hard-coded here, in the same
        # model order as above; presumably produced by a separate run —
        # confirm they correspond to the same CSVs before trusting the plot.
        pmr_scores=[0.4321, 0.0412, 0.5890],
        kendall_tau_scores=[0.7253, 0.3449, 0.8361],
    )


if __name__ == "__main__":
    import traceback

    try:
        main()
    except Exception as e:
        # Keep the one-line summary, but also emit the traceback: str(e)
        # alone drops the exception type and location (a KeyError for a
        # missing CSV column would print nothing but the bare column name).
        print(f"An unexpected error occurred: {e}")
        traceback.print_exc()

--- Adjacent Sentence Accuracy (ASA) Results ---
GPT-5 Nano ASA: 0.6253 (6118 '/' 9784 pairs correct)
LLaMA-3 ASA:    0.2507 (68593 '/' 273652 pairs correct)
Qwen ASA:       0.7350 (13395 '/' 18224 pairs correct)
----------------------------------------------
Overall ASA:    0.2921 (88106 '/' 301660 pairs correct)

Generated all-metrics comparison plot: full_comp.png
