In [None]:
import ast
import json
import os
import re

import numpy as np
import openai
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# API key is read from the environment rather than hard-coded. os.getenv returns
# None when OPENAI_API_KEY is unset; the grading calls in score_rationale_quality
# will then presumably fail with an authentication error.
openai.api_key = os.getenv("OPENAI_API_KEY")
# Shared BLEU smoothing (NLTK method1), passed to sentence_bleu in
# compute_bleu_diversity. Presumably chosen so short rationales with no
# higher-order n-gram overlap don't score exactly 0 — see NLTK SmoothingFunction docs.
smooth = SmoothingFunction().method1
# Sentence-embedding model read by compute_embedding_diversity; constructed once
# when this cell runs instead of on every call.
model = SentenceTransformer('all-MiniLM-L6-v2')


def _parse_score_list(text):
    """
    Extract a ``[coherence, correctness, relevance]`` score list from grader output.

    The model may wrap the list in prose or a Markdown code fence, or echo the
    placeholder ``[coherence, correctness, relevance]`` before answering. So every
    bracketed span is tried in order, and the first one that is a literal list of
    exactly three numbers in [1, 5] is returned. ``ast.literal_eval`` only accepts
    Python literals, so model output is never executed as code.

    Raises:
        ValueError: if no bracketed span parses to three in-range scores.
    """
    for candidate in re.findall(r"\[[^\[\]]*\]", text):
        try:
            value = ast.literal_eval(candidate)
        except (ValueError, SyntaxError, TypeError):
            continue  # e.g. the echoed "[coherence, correctness, relevance]" template
        if (
            isinstance(value, list)
            and len(value) == 3
            and all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in value)
            and all(1 <= v <= 5 for v in value)
        ):
            return value
    raise ValueError(f"Could not parse three 1-5 scores from grader output: {text!r}")


def score_rationale_quality(rationale, question):
    """
    Ask GPT-4 to grade one rationale on coherence, correctness and relevance.

    Args:
        rationale: the rationale text to grade.
        question: the question the rationale answers (may be "").

    Returns:
        list: ``[coherence, correctness, relevance]``, each a number from 1 to 5.

    Raises:
        ValueError: if the grader's reply contains no valid three-score list.
    """
    prompt = f"""You are a helpful AI grader. Evaluate the following rationale for the given question.
    
    Question: {question}
    Rationale: {rationale}

    Please rate the rationale on a scale from 1 to 5 for:
    1. Coherence (how logically consistent the reasoning is)
    2. Correctness (whether the facts stated are accurate)
    3. Relevance (whether the reasoning is focused on the question)

    Output the scores as a Python list: [coherence, correctness, relevance]
    """

    response = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[{"role": "system", "content": "You are a helpful grader."},
                  {"role": "user", "content": prompt}],
        temperature=0.2
    )
    content = response['choices'][0]['message']['content'] or ""
    # Never eval() model output: the rationale text can steer what the model
    # replies with, so eval would be an arbitrary-code-execution hole.
    return _parse_score_list(content)


def compute_bleu_diversity(rationales):
    """
    Average pairwise sentence-BLEU over all unordered pairs (lower = more diverse).

    For each pair ``i < j``, rationale ``i`` is the reference and rationale ``j``
    is the hypothesis. Texts are tokenized on whitespace and scored with the
    module-level ``smooth`` smoothing function.

    Args:
        rationales: sequence of rationale strings.

    Returns:
        float: mean BLEU over the n*(n-1)/2 pairs, or ``nan`` when fewer than two
        rationales are given (no pairs exist, so diversity is undefined).
    """
    # Tokenize each text once rather than once per pair it appears in.
    tokenized = [text.split() for text in rationales]
    pair_scores = [
        sentence_bleu([tokenized[i]], tokenized[j], smoothing_function=smooth)
        for i in range(len(tokenized))
        for j in range(i + 1, len(tokenized))
    ]
    if not pair_scores:
        # np.mean([]) is also nan, but emits a "Mean of empty slice" RuntimeWarning.
        return float("nan")
    return float(np.mean(pair_scores))


def compute_embedding_diversity(rationales):
    """
    Mean pairwise cosine similarity of sentence embeddings (lower = more diverse).

    Embeds every rationale with the module-level ``model``, then averages the
    strict upper triangle of the cosine-similarity matrix. That counts each
    unordered pair once and excludes self-similarities.

    Args:
        rationales: sequence of rationale strings.

    Returns:
        float: mean cosine similarity over all pairs, or ``nan`` when fewer than
        two rationales are given (no pairs exist).
    """
    rationales = list(rationales)
    if len(rationales) < 2:
        return float("nan")
    # Ask for NumPy output: convert_to_tensor=True returns a tensor on the model's
    # device, and sklearn cannot convert a CUDA tensor into an array.
    embeddings = model.encode(rationales, convert_to_numpy=True)
    sim_matrix = cosine_similarity(embeddings)
    upper_tri = sim_matrix[np.triu_indices_from(sim_matrix, k=1)]
    return float(np.mean(upper_tri))


def evaluate_rationales(rationales, questions=None):
    """
    Run the full evaluation: LLM-graded quality plus two diversity metrics.

    Args:
        rationales: non-empty sequence of rationale strings.
        questions: optional sequence of questions aligned 1:1 with ``rationales``.
            When omitted, each rationale is graded against an empty question.

    Returns:
        dict with
          - "average_quality": mean "coherence", "correctness" and "relevance"
            scores across all rationales (from score_rationale_quality);
          - "diversity": "bleu_similarity" and "embedding_similarity", both mean
            pairwise similarities (lower = more diverse).

    Raises:
        ValueError: if ``rationales`` is empty, or if ``questions`` has a different
            length. Without this check, ``zip`` would silently skip grading the
            extra rationales while the diversity metrics still include them.
    """
    rationales = list(rationales)
    if not rationales:
        raise ValueError("evaluate_rationales needs at least one rationale")
    questions = [""] * len(rationales) if questions is None else list(questions)
    if len(questions) != len(rationales):
        raise ValueError(
            f"got {len(rationales)} rationales but {len(questions)} questions"
        )

    print("Scoring rationale quality...")
    # One row per rationale: [coherence, correctness, relevance].
    quality_scores = np.array(
        [score_rationale_quality(r, q) for r, q in zip(rationales, questions)],
        dtype=float,
    )
    coherence, correctness, relevance = quality_scores.mean(axis=0)

    print("Computing diversity metrics...")
    bleu_div = compute_bleu_diversity(rationales)
    embed_div = compute_embedding_diversity(rationales)

    return {
        "average_quality": {
            "coherence": coherence,
            "correctness": correctness,
            "relevance": relevance,
        },
        "diversity": {
            "bleu_similarity": bleu_div,
            "embedding_similarity": embed_div
        }
    }

In [None]:
import json 
def _extract_last_question(prompt_text):
    """Return the last ``Q:`` section of a prompt, cut at the first ``<|im_end|>`` ("" if none)."""
    stripped = prompt_text.strip()
    # Search and slice the same (stripped) string so the index lines up; slicing
    # the unstripped text would shift the cut by the leading whitespace.
    q_start = stripped.rfind("Q:")
    if q_start == -1:
        return ""
    return stripped[q_start:].split("<|im_end|>")[0].strip()


def _extract_rationale(full_text):
    """Return the text after the ``A:`` line that follows the last ``Q:``, or "MISSING"."""
    q_start = full_text.rfind("Q:")
    if q_start == -1:
        return "MISSING"
    parts = full_text[q_start:].split("\nA:")
    if len(parts) < 2:
        return "MISSING"
    rationale = parts[1].strip()
    # Drop a leading "ystem" artifact (presumably a chat-template role fragment).
    # Only strip it as a prefix: a global replace would also mangle real words
    # such as "system" or "ecosystem".
    if rationale.startswith("ystem"):
        rationale = rationale[len("ystem"):].strip()
    return rationale


def extract_rationales_from_jsonl(path, drop_missing=False):
    """
    Read a JSONL file of generations and pull out aligned (rationale, question) pairs.

    Each non-blank line must be a JSON object.
      - Its "prompt" field supplies the question: the last ``Q:`` section, cut at
        ``<|im_end|>``.
      - Its "generated_text" field supplies the rationale: the text after ``A:``
        that follows the last ``Q:``. If none is found, the placeholder
        "MISSING" is used.
    Absent or null fields are treated as empty strings.

    Args:
        path: path to the .jsonl file (read as UTF-8).
        drop_missing: if True, skip records whose rationale is "MISSING" instead
            of keeping the placeholder (default False keeps them).

    Returns:
        (rationales, questions): two equal-length lists aligned by index.
    """
    rationales = []
    questions = []

    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue  # tolerate blank separator lines instead of crashing json.loads
            entry = json.loads(line)

            question = _extract_last_question(entry.get("prompt") or "")
            rationale = _extract_rationale(entry.get("generated_text") or "")

            if not rationale or (drop_missing and rationale == "MISSING"):
                continue
            rationales.append(rationale)
            questions.append(question)

    return rationales, questions

In [None]:
# Smoke test: three hand-written paraphrases of one answer to a single question.
# evaluate_rationales calls score_rationale_quality once per rationale, so this
# cell makes live GPT-4 requests. Names are demo_-prefixed so they don't collide
# with the rationales/questions loaded from file later in the notebook.
demo_rationales = [
    "The grape is placed in the grocery cart before checkout.",
    "You put your grapes in the cart before you go to the cashier.",
    "Before buying grapes, they are placed in the grocery cart.",
]
demo_questions = [
    "Where do you put your grapes just before checking out?",
] * len(demo_rationales)

demo_result = evaluate_rationales(demo_rationales, demo_questions)
demo_result

In [None]:
# Load rationales + questions from a generations .jsonl file and evaluate them.
# TODO: set JSONL_FILE_PATH to the generations file before running this cell.
JSONL_FILE_PATH = ""

if JSONL_FILE_PATH:
    file_rationales, file_questions = extract_rationales_from_jsonl(JSONL_FILE_PATH)
    results = evaluate_rationales(file_rationales, file_questions)
else:
    # Guard so Restart & Run All doesn't stop on open("") with FileNotFoundError.
    results = None
    print("JSONL_FILE_PATH is not set; skipping evaluation of the JSONL file.")
results