In [18]:
from datasets import load_dataset

acorn = load_dataset("json", data_files={"train": "/content/ACORN.jsonl"})["train"]
print("Rows:", len(acorn))
print("Keys:", acorn[0].keys())

Rows: 3500
Keys: dict_keys(['id', 'q_id', 'q_source', 'question', 'choices', 'label', 'e_id', 'explanation', 'triples', 'positives', 'negatives', 'e_source', 'voted_ratings', 'worker_ratings'])


In [19]:
row = acorn[0]
for k in ["question","choices","label","explanation","voted_ratings"]:
    print(k, "=>", row[k] if k!="explanation" else row[k][:200]+" ...")

question => John is an omnidiciplinarian. Where might he find success?
choices => ['ocean', 'working hard', 'various situations', 'michigan', 'awards ceremony']
label => 2
explanation => prefix denotes all ...
voted_ratings => {'supports': 2, 'overall': 1, 'well_written': 0, 'related': 1, 'factual': 1, 'new_info': 1, 'unnecessary_info': 0, 'contrastive': 0}


In [20]:
ACORN_PROMPT = """\
You are rating the QUALITY of a free-text explanation to a commonsense question,
using ACORN criteria (match the label space EXACTLY).

Question:
{question}

Choices (label the chosen letter):
{choices_block}

Explanation to rate:
\"\"\"{explanation}\"\"\"

Return ONLY a JSON object with these fields and these exact label sets:

{{
  "supports": "a|b|c|d|e|none",
  "overall_1to5": 1..5,
  "well_written": "Yes|No",
  "related": "Yes|No",
  "factual": "Yes|No|N/A",
  "new_info": "None|Some|Sufficient|Ample",
  "unnecessary_info": "Yes|No",
  "contrastive": "Yes|No"
}}

Guidelines (short):
- Supports: which choice the explanation argues for (if none applies, use "none").
- Overall: holistic 1–5 quality.
- Well-Written: fluency/grammar/coherence.
- Related: relevant to the Q and A?
- Factual: are stated facts generally true (N/A if no factual claims)?
- New Information: how much new info beyond Q/choices?
- Unnecessary Info: does it include irrelevant info?
- Contrastive: does it clearly contrast correct vs. other choices?
"""

def make_choices_block(choices):
    letters = "abcdefghijklmnopqrstuvwxyz"
    return "\n".join([f"{letters[i]}) {c}" for i, c in enumerate(choices)])

def build_acorn_user_prompt(row):
    return ACORN_PROMPT.format(
        question=row["question"],
        choices_block=make_choices_block(row["choices"]),
        explanation=row["explanation"]
    )


In [21]:
import json
from jsonschema import validate, ValidationError

ACORN_SCHEMA = {
  "type": "object",
  "properties": {
    "supports": {"type": "string", "enum": ["a","b","c","d","e","none"]},
    "overall_1to5": {"type": "integer", "minimum": 1, "maximum": 5},
    "well_written": {"type": "string", "enum": ["Yes","No"]},
    "related": {"type": "string", "enum": ["Yes","No"]},
    "factual": {"type": "string", "enum": ["Yes","No","N/A"]},
    "new_info": {"type": "string", "enum": ["None","Some","Sufficient","Ample"]},
    "unnecessary_info": {"type": "string", "enum": ["Yes","No"]},
    "contrastive": {"type": "string", "enum": ["Yes","No"]}
  },
  "required": ["supports","overall_1to5","well_written","related","factual",
               "new_info","unnecessary_info","contrastive"],
  "additionalProperties": False
}

def repair_minor_misses(d):
    # common normalizations
    if "support" in d and "supports" not in d:
        d["supports"] = d.pop("support")
    if "overall" in d and "overall_1to5" not in d:
        try:
            d["overall_1to5"] = int(d.pop("overall"))
        except: pass

    # lowercase supports; map weirds
    if "supports" in d and isinstance(d["supports"], str):
        low = d["supports"].strip().lower()
        # sometimes returns the text or index; map to letter when possible
        if low in ["a","b","c","d","e","none"]:
            d["supports"] = low
        elif low in ["n/a","na"]:
            d["supports"] = "none"

    # cast overall
    if "overall_1to5" in d:
        try: d["overall_1to5"] = int(d["overall_1to5"])
        except: pass

    return d

def rate_explanation_with_acorn(system_prompt, user_prompt, model="gpt-4o-mini", temperature=0.2):
    resp = client.responses.create(
        model=model,
        input=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt + "\n\nReturn ONLY the JSON object."},
        ],
        temperature=temperature,
        max_output_tokens=400,
    )
    text = resp.output_text.strip()
    raw = extract_first_json_object(text)
    data = json.loads(raw)
    try:
        validate(instance=data, schema=ACORN_SCHEMA)
        return data
    except ValidationError:
        data = repair_minor_misses(data)
        validate(instance=data, schema=ACORN_SCHEMA)  # will raise if still bad
        return data


In [22]:
import os
from openai import OpenAI

os.environ["OPENAI_API_KEY"] = "sk..."
assert os.environ.get("OPENAI_API_KEY", "").startswith("sk-"), "Missing OPENAI_API_KEY"
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])


In [23]:
def extract_first_json_object(s: str) -> str:
    """
    Returns the first top-level JSON object found in s.
    Raises ValueError if none or braces unbalanced.
    """
    start = s.find("{")
    if start == -1:
        raise ValueError("No '{' found in model output.")
    depth = 0
    in_str = False
    esc = False
    for i in range(start, len(s)):
        c = s[i]
        if in_str:
            if esc:
                esc = False
            elif c == "\\":
                esc = True
            elif c == '"':
                in_str = False
        else:
            if c == '"':
                in_str = True
            elif c == "{":
                depth += 1
            elif c == "}":
                depth -= 1
                if depth == 0:
                    return s[start:i+1]
    raise ValueError("Unbalanced JSON braces in model output.")


In [24]:
import random, pandas as pd

random.seed(123)
idxs = random.sample(range(len(acorn)), k=25)  # keep small first to manage cost

PERSONA = "You are a careful but concise ACORN judge. Follow the rubric exactly; choose labels only from the allowed sets."

preds = []
for i in idxs:
    row = acorn[i]
    up = build_acorn_user_prompt(row)
    pred = rate_explanation_with_acorn(PERSONA, up)
    # keep key metadata for later join
    pred_record = {
        "id": row.get("id", i),
        "supports_pred": pred["supports"],
        "overall_pred": pred["overall_1to5"],
        "well_written_pred": pred["well_written"],
        "related_pred": pred["related"],
        "factual_pred": pred["factual"],
        "new_info_pred": pred["new_info"],
        "unnecessary_info_pred": pred["unnecessary_info"],
        "contrastive_pred": pred["contrastive"],
    }
    preds.append(pred_record)
    print(f"[{i}] OK → {pred_record}")

pred_df = pd.DataFrame(preds)
len(pred_df), pred_df.head(3)


[214] OK → {'id': 'test_0214', 'supports_pred': 'b', 'overall_pred': 4, 'well_written_pred': 'Yes', 'related_pred': 'Yes', 'factual_pred': 'Yes', 'new_info_pred': 'Some', 'unnecessary_info_pred': 'No', 'contrastive_pred': 'Yes'}
[1096] OK → {'id': 'test_1096', 'supports_pred': 'a', 'overall_pred': 4, 'well_written_pred': 'Yes', 'related_pred': 'Yes', 'factual_pred': 'Yes', 'new_info_pred': 'None', 'unnecessary_info_pred': 'No', 'contrastive_pred': 'No'}
[357] OK → {'id': 'test_0357', 'supports_pred': 'd', 'overall_pred': 4, 'well_written_pred': 'Yes', 'related_pred': 'Yes', 'factual_pred': 'Yes', 'new_info_pred': 'Some', 'unnecessary_info_pred': 'No', 'contrastive_pred': 'No'}
[3149] OK → {'id': 'test_3149', 'supports_pred': 'b', 'overall_pred': 4, 'well_written_pred': 'Yes', 'related_pred': 'Yes', 'factual_pred': 'Yes', 'new_info_pred': 'Some', 'unnecessary_info_pred': 'No', 'contrastive_pred': 'Yes'}
[1668] OK → {'id': 'test_1668', 'supports_pred': 'b', 'overall_pred': 4, 'well_writt

(25,
           id supports_pred  overall_pred well_written_pred related_pred  \
 0  test_0214             b             4               Yes          Yes   
 1  test_1096             a             4               Yes          Yes   
 2  test_0357             d             4               Yes          Yes   
 
   factual_pred new_info_pred unnecessary_info_pred contrastive_pred  
 0          Yes          Some                    No              Yes  
 1          Yes          None                    No               No  
 2          Yes          Some                    No               No  )

In [26]:
import pandas as pd
from scipy.stats import spearmanr

letters = "abcdefghijklmnopqrstuvwxyz"

def norm_yesno(x, *, default="No"):
    # Accept "Yes"/"No", booleans, or weird casing
    if isinstance(x, bool):
        return "Yes" if x else "No"
    if isinstance(x, str):
        t = x.strip().lower()
        if t in ("yes","y","true","1"): return "Yes"
        if t in ("no","n","false","0"): return "No"
        if t in ("n/a","na","none",""): return default
    return default

def norm_factual(x):
    # Valid set: "Yes", "No", "N/A"
    if isinstance(x, str):
        t = x.strip().lower()
        if t in ("yes","true","1"): return "Yes"
        if t in ("no","false","0"): return "No"
        if t in ("n/a","na","none","null",""): return "N/A"
    if isinstance(x, bool):
        return "Yes" if x else "No"
    return "N/A"

def norm_new_info(x):
    # Valid set: "None","Some","Sufficient","Ample"
    if isinstance(x, str):
        t = x.strip().lower()
        if t in ("none","no","n/a","na",""): return "None"
        if t in ("some","a little","limited"): return "Some"
        if t in ("sufficient","adequate"): return "Sufficient"
        if t in ("ample","lots","a lot","substantial","extensive"): return "Ample"
    if x is None: return "None"
    return "Some"  # safe fallback

def to_letter_from_support(support_value, choices):
    """
    Map human 'supports' value (could be 'a', 0, 'The red option', None) to 'a'..'e' or 'none'.
    """
    # 1) None-like
    if support_value is None:
        return "none"

    # 2) string letter
    if isinstance(support_value, str):
        t = support_value.strip().lower()
        if t in ("none","n/a","na",""): return "none"
        if t in list(letters[:len(choices)]):  # "a","b",...
            return t
        # maybe it's a numeric string
        if t.isdigit():
            idx = int(t)
            if 0 <= idx < len(choices):
                return letters[idx]
        # maybe it's the actual choice text -> fuzzy exact match
        # try exact, case-insensitive
        for i, c in enumerate(choices):
            if t == str(c).strip().lower():
                return letters[i]
        return "none"

    # 3) integer index
    if isinstance(support_value, int):
        if 0 <= support_value < len(choices):
            return letters[support_value]
        return "none"

    # 4) unknown type
    return "none"

def row_to_gold(r):
    vr = r.get("voted_ratings", {}) or {}
    choices = r.get("choices", []) or []

    # supports can be letter, index, or text
    supports_raw = vr.get("supports", None)
    supports_letter = to_letter_from_support(supports_raw, choices)

    # overall may be str or int
    overall_raw = vr.get("overall", None)
    try:
        overall_int = int(overall_raw)
    except:
        overall_int = 3  # neutral fallback

    # keys might vary slightly across dumps
    new_info_raw = vr.get("new_information", vr.get("new_info", None))
    unnecessary_raw = vr.get("unnecessary_information", vr.get("unnecessary_info", None))

    return {
        "id": r.get("id", None),
        "supports_gold": supports_letter,
        "overall_gold": overall_int,
        "well_written_gold": norm_yesno(vr.get("well_written", "No")),
        "related_gold": norm_yesno(vr.get("related", "No")),
        "factual_gold": norm_factual(vr.get("factual", "N/A")),
        "new_info_gold": norm_new_info(new_info_raw),
        "unnecessary_info_gold": norm_yesno(unnecessary_raw, default="No"),
        "contrastive_gold": norm_yesno(vr.get("contrastive", "No")),
    }

# Build gold for the sampled indices
gold_rows = [row_to_gold(acorn[i]) for i in idxs]
gold_df = pd.DataFrame(gold_rows)

# Join predictions and gold
joined = pd.concat([pred_df.reset_index(drop=True), gold_df.reset_index(drop=True)], axis=1)

# Accuracy for categoricals
def acc(col_pred, col_gold):
    return float((joined[col_pred] == joined[col_gold]).mean())

metrics = {
    "supports_acc": acc("supports_pred", "supports_gold"),
    "well_written_acc": acc("well_written_pred", "well_written_gold"),
    "related_acc": acc("related_pred", "related_gold"),
    "factual_acc": acc("factual_pred", "factual_gold"),
    "new_info_acc": acc("new_info_pred", "new_info_gold"),
    "unnecessary_info_acc": acc("unnecessary_info_pred", "unnecessary_info_gold"),
    "contrastive_acc": acc("contrastive_pred", "contrastive_gold"),
}

# Spearman correlation for 'overall'
rho, p = spearmanr(joined["overall_pred"], joined["overall_gold"])
metrics["overall_spearman_r"] = float(rho)
metrics["overall_spearman_p"] = float(p)

print("=== Agreement vs ACORN majority-vote (n=%d) ===" % len(joined))
for k,v in metrics.items():
    print(f"{k:28s} {v:.3f}")

joined.head(5)


=== Agreement vs ACORN majority-vote (n=25) ===
supports_acc                 0.920
well_written_acc             0.160
related_acc                  0.160
factual_acc                  0.160
new_info_acc                 0.640
unnecessary_info_acc         0.880
contrastive_acc              0.600
overall_spearman_r           0.873
overall_spearman_p           0.000


Unnamed: 0,id,supports_pred,overall_pred,well_written_pred,related_pred,factual_pred,new_info_pred,unnecessary_info_pred,contrastive_pred,id.1,supports_gold,overall_gold,well_written_gold,related_gold,factual_gold,new_info_gold,unnecessary_info_gold,contrastive_gold
0,test_0214,b,4,Yes,Yes,Yes,Some,No,Yes,test_0214,b,3,No,No,,Some,No,No
1,test_1096,a,4,Yes,Yes,Yes,,No,No,test_1096,a,3,No,No,,Some,No,No
2,test_0357,d,4,Yes,Yes,Yes,Some,No,No,test_0357,d,3,No,No,,Some,No,No
3,test_3149,b,4,Yes,Yes,Yes,Some,No,Yes,test_3149,b,4,No,No,,Some,No,No
4,test_1668,b,4,Yes,Yes,Yes,Some,No,Yes,test_1668,b,5,No,No,,Some,No,No
