In [2]:
import csv
import os
import sys

# Raise the csv field-size limit as far as this platform allows: start from
# sys.maxsize and halve the candidate every time the csv module rejects it
# with OverflowError, stopping at the first value that is accepted.
max_int = sys.maxsize
while True:
    try:
        csv.field_size_limit(max_int)
    except OverflowError:
        max_int //= 2
    else:
        break

# Read every gold story: the 'gold' column is split on '|' and each sentence
# has its runs of whitespace collapsed to single spaces.
with open("data/processed/train_processed.csv", encoding="utf-8") as gold_file:
    gold_stories = [
        [' '.join(sentence.split()) for sentence in record['gold'].split('|')]
        for record in csv.DictReader(gold_file)
    ]

# Order-insensitive lookup: sorted sentence tuple -> index into gold_stories.
# When two stories share the same sentences, the later index wins.
gold_set = {}
for story_position, sentences in enumerate(gold_stories):
    gold_set[tuple(sorted(sentences))] = story_position

def clean_model_output(model_csv, cleaned_csv, bad_rows_csv, gold=None):
    """Split a model's reordered stories into accepted rows and rejected rows.

    Every row of ``model_csv`` is read through its ``reordered_story`` column,
    which holds sentences separated by ``|``.  Sentences are whitespace-
    normalised, and the row is accepted only when it

    * contains exactly five sentences,
    * has no sentence with more than one ``.``, and
    * is a permutation of the sentences of some gold story.

    Accepted rows are written to ``cleaned_csv`` with the columns
    ``original_story`` (the matching gold story) and ``reordered_story``.
    All other rows are written to ``bad_rows_csv`` with their input columns
    plus a ``reason`` column.  The bad-rows file always gets a header, even
    when no row was rejected.

    Args:
        model_csv: Path of the CSV holding the model predictions.
        cleaned_csv: Destination path for the accepted rows.
        bad_rows_csv: Destination path for the rejected rows.
        gold: Optional iterable of gold stories, each an iterable of sentence
            strings.  When omitted, the module-level ``gold_stories`` and
            ``gold_set`` are used.

    Raises:
        KeyError: If a row has no ``reordered_story`` key, i.e. the input
            header lacks that column.
    """
    if gold is None:
        # Default: rely on the gold data loaded at notebook level.
        stories, story_lookup = gold_stories, gold_set
    else:
        # Normalise the same way predictions are normalised so keys compare equal.
        stories = [[' '.join(s.split()) for s in story] for story in gold]
        story_lookup = {tuple(sorted(story)): i for i, story in enumerate(stories)}

    cleaned_rows = []
    bad_rows = []

    # newline="" leaves line-ending handling inside quoted fields to the csv module.
    with open(model_csv, encoding="utf-8", newline="") as f:
        reader = csv.DictReader(f)
        input_fields = list(reader.fieldnames or [])
        for row in reader:
            pred_raw = row['reordered_story']
            if pred_raw is None:
                # DictReader fills the missing columns of a short row with None.
                bad_rows.append({**row, "reason": "missing reordered_story"})
                continue

            pred = [' '.join(s.split()) for s in pred_raw.split('|')]

            if len(pred) != 5:
                bad_rows.append({**row, "reason": "wrong sentence count"})
                continue

            if any(s.count('.') > 1 for s in pred):
                bad_rows.append({**row, "reason": "merged or corrupted sentence"})
                continue

            pred_key = tuple(sorted(pred))
            if pred_key in story_lookup:
                original_gold_story = stories[story_lookup[pred_key]]
                cleaned_rows.append({
                    "original_story": ' | '.join(original_gold_story),
                    "reordered_story": ' | '.join(pred)
                })
            else:
                bad_rows.append({**row, "reason": "not a permutation of any gold story"})

    # Create the directories the caller actually asked to write into.
    for out_path in (cleaned_csv, bad_rows_csv):
        out_dir = os.path.dirname(out_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    with open(cleaned_csv, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=["original_story", "reordered_story"])
        writer.writeheader()
        writer.writerows(cleaned_rows)

    # Header = input columns, then any key seen on a rejected row, so a row
    # with extra keys cannot make DictWriter raise; "reason" is always present.
    bad_fields = dict.fromkeys(input_fields)
    for bad in bad_rows:
        bad_fields.update(dict.fromkeys(bad))
    bad_fields.setdefault("reason")

    with open(bad_rows_csv, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=list(bad_fields))
        writer.writeheader()
        writer.writerows(bad_rows)


# One (model predictions, cleaned output, rejected output) triple per model;
# each triple is passed to clean_model_output in the order listed.
for predictions_path, cleaned_path, rejected_path in [
    (
        "data/processed/train_reordered_pairs_gpt5_nano.csv",
        "data/cleaned/train_cleaned_gpt5_nano.csv",
        "data/cleaned/train_bad_gpt5_nano.csv",
    ),
    (
        "data/processed/train_reordered_pairs_llama3.csv",
        "data/cleaned/train_cleaned_llama3.csv",
        "data/cleaned/train_bad_llama3.csv",
    ),
    (
        "data/processed/train_reordered_pairs_qwen.csv",
        "data/cleaned/train_cleaned_qwen.csv",
        "data/cleaned/train_bad_qwen.csv",
    ),
]:
    clean_model_output(predictions_path, cleaned_path, rejected_path)

In [1]:
from kendall_tau import compute_kendall_tau

def _mean_or_nan(values):
    """Return the arithmetic mean of ``values``, or NaN when ``values`` is empty.

    Guards the averages below against ZeroDivisionError when every row of a
    model's cleaned file was dropped.
    """
    if len(values) == 0:
        return float("nan")
    return sum(values) / len(values)


# compute_kendall_tau's result is unpacked as (tau scores, dropped-row count);
# presumably one tau per story that could be scored — confirm in kendall_tau.py.
gpt5_nano_taus, gpt5_nano_dropped_rows = compute_kendall_tau("data/processed/train_processed.csv", "data/cleaned/train_cleaned_gpt5_nano.csv")
print("Average gpt5 nano Kendall tau:", _mean_or_nan(gpt5_nano_taus))
print("gpt5 nano dropped rows: ", gpt5_nano_dropped_rows)

llama3_taus, llama3_dropped_rows = compute_kendall_tau("data/processed/train_processed.csv", "data/cleaned/train_cleaned_llama3.csv")
print("Average llama3 Kendall tau:", _mean_or_nan(llama3_taus))
print("llama3 dropped rows: ", llama3_dropped_rows)

qwen_taus, qwen_dropped_rows = compute_kendall_tau("data/processed/train_processed.csv", "data/cleaned/train_cleaned_qwen.csv")
print("Average qwen Kendall tau:", _mean_or_nan(qwen_taus))
print("qwen dropped rows: ", qwen_dropped_rows)

Average gpt5 nano Kendall tau: 0.6959409594095943
gpt5 nano dropped rows:  2377
Average llama3 Kendall tau: 0.7600000000000003
llama3 dropped rows:  63372
Average qwen Kendall tau: 0.8255813953488369
qwen dropped rows:  10438


In [2]:
from story_position_acc import count_sentence_mismatches, get_sentence_2_3_mismatches

# Denominator for the accuracy figures below. Hard-coded; presumably the
# number of stories in train_reordered_pairs.csv — TODO confirm.
story_total = 1375

mismatches = count_sentence_mismatches("data/processed/train_reordered_pairs.csv")
print("Total sentence mismatches:", mismatches)
# One accuracy value per entry of `mismatches`: the share of stories without a mismatch.
accuracies = [(story_total - m) / story_total for m in mismatches]
print("Average sentence accuracy per story:", accuracies)

# Show up to k=10 examples returned by get_sentence_2_3_mismatches.
examples = get_sentence_2_3_mismatches("data/processed/train_reordered_pairs.csv", k=10)
for example in examples:
    print("ORIG: ", example["orig"])
    print("REORD:", example["reord"])
    print()


Total sentence mismatches: [188, 272, 239, 206, 120]
Average sentence accuracy per story: [0.8632727272727273, 0.8021818181818182, 0.8261818181818181, 0.8501818181818181, 0.9127272727272727]
ORIG:  ['I asked for a drink afterwards, but he told me no', 'I was forced to do several push-ups']
REORD: ['He told me to stop as soon as I fell on the floor', 'I was forced to do several push-ups']

ORIG:  ['She sighed in frustration that she had ruined it', 'Picking it up, she put it back in the car']
REORD: ['The soda hit the ground hard', 'Picking it up, she put it back in the car']

ORIG:  ['She had been playing with his puppy before he got angry', 'She ran for blocks without looking back']
REORD: ['Lucy was running from an angry old man', 'She ran for blocks without looking back']

ORIG:  ['His doctor asked Max if he slept through the night regularly', "Max couldn't remember the last time he had slept through the night"]
REORD: ['Max noticed he was always tired so he went to the doctor for h