In [14]:
import json
# Read the cached short-answer ratings, then index question text and gold
# answer by question id.
with open("./shortanswer_ratings_cache.json", 'r', encoding='utf-8') as cache_file:
    sa_cache = json.load(cache_file)

sa_qdict = {}
for qdata in sa_cache.values():
    qid = qdata['qid']
    if qid in sa_qdict:
        # Several cache entries can share a qid; the first one seen wins.
        continue
    sa_qdict[qid] = dict(question=qdata['question'], correct_answer=qdata['correct_answer'])

In [4]:
# NOTE: despite its name, `claude_results` holds the grok-3-latest phase-1
# compilation; the name is kept in case other cells refer to it.
with open("./compiled_results_sqa/grok-3-latest_phase1_compiled.json", 'r', encoding='utf-8') as results_file:
    claude_results = json.load(results_file)

# Report every compiled-result id that has no entry in sa_qdict.
missing_qids = [qid for qid in claude_results["results"] if qid not in sa_qdict]
for qid in missing_qids:
    print(f"Warning: {qid} not found in sa_qdict")


In [None]:
import re
from base_game_class import BaseGameClass

judge_model = "claude-opus-4-20250514"
sa_to_mc_file = "./SimpleMC.jsonl"
# Upper bound on judge calls per question, so a model that keeps returning
# malformed output cannot loop (and bill) forever.
max_attempts = 5

judge = BaseGameClass(subject_id=None, subject_name=judge_model, is_human_player=False, log_dir=None)
sysprompt=""
prompt = """I need your help turning a short-answer quiz into a multiple-choice quiz. I'm going to show you a question and its correct answer, and I want you to generate three distractors. 
Each distractor should be a plausible answer that is NOT the correct answer. Each should be the same \"type\" of answer as the correct answer (e.g., date, person name, number, etc), and follow the format of the correct answer.
Output each distractor a separate line, and do not include any other text. Your entire response should be just the distractors, one per line.
Here is the question and correct answer:

Question: {question}

Correct Answer: {correct_answer}

Distractors:
"""


def _parse_distractors(resp, correct_answer):
    """Return the three distractors found in ``resp``, or None if it is unusable.

    A response is accepted only when it holds exactly three non-blank lines
    that differ (case-insensitively) from each other and from
    ``correct_answer``. Lines are returned stripped of surrounding whitespace.
    """
    if not resp:
        return None
    candidates = [line.strip() for line in re.split(r'\n+', resp)]
    candidates = [line for line in candidates if line]
    if len(candidates) != 3:
        return None
    distinct = {c.upper() for c in candidates} | {correct_answer.upper()}
    if len(distinct) != 4:
        # Either a repeated distractor or a distractor equal to the answer.
        return None
    return candidates


# NOTE: the file is opened in append mode and nothing checks for qids that are
# already present, so re-running this cell appends duplicate records.
with open(sa_to_mc_file, 'a', encoding='utf-8') as fout:
    for ctr, (qid, qdata) in enumerate(sa_qdict.items()):
        question = qdata['question']
        correct_answer = qdata['correct_answer'].strip()
        print(f"Processing question {ctr+1}, ID: {qid}")

        q_text = prompt.format(question=question, correct_answer=correct_answer)
        ans_list = None
        for _attempt in range(max_attempts):
            resp, _, _ = judge._get_llm_answer(options=None, q_text=q_text, message_history=[], keep_appending=False, setup_text=sysprompt, MAX_TOKENS=None, temp=1.0)
            ans_list = _parse_distractors(resp, correct_answer)
            if ans_list is not None:
                break
            print(f"Invalid response format for question {qid}. Retrying...")

        if ans_list is None:
            print(f"Giving up on question {qid} after {max_attempts} attempts. Skipping.")
            continue

        # Write and flush per question so an interrupted run keeps its progress.
        fout.write(json.dumps({"qid": qid, "question": question, "correct_answer": correct_answer, "distractors": ans_list}, ensure_ascii=False) + "\n")
        fout.flush()


Provider: Anthropic
Processing question 499, ID: sqa_test_2469
In model_call, provider=Anthropic, attempt=1
Processing question 500, ID: sqa_test_1573
In model_call, provider=Anthropic, attempt=1


In [27]:
from load_and_format_datasets import load_and_format_dataset
# Load the SimpleMC dataset through the project loader; `qs` is iterated as a
# sequence of question dicts (with a 'question' key) in the following cell.
qs=load_and_format_dataset("SimpleMC")


Attempting to load SimpleMC...
Dataset loaded successfully.
Attempting to load SimpleQA (test split)...
Dataset loaded successfully.
Formatting 4326 questions...
Successfully formatted 4326 unique questions from SimpleQA.
Formatting 500 questions...
Successfully formatted 500 unique questions from SimpleMC.


In [28]:
# Locate the first SimpleMC entry whose text matches the fencing question and
# show its index together with the full record.
target_question = 'Which of the three Olympic fencing weapons was the last one to transition to using electrical equipment?'
first_match = next(
    ((idx, item) for idx, item in enumerate(qs) if item['question'] == target_question),
    None,
)
if first_match is not None:
    i, q = first_match
    print(f"Found question at index {i}: {q}")

Found question at index 36: {'id': 'sqa_test_479', 'question': 'Which of the three Olympic fencing weapons was the last one to transition to using electrical equipment?', 'options': {'A': 'Sabre', 'B': 'Foil', 'C': 'Rapier', 'D': 'Épée'}, 'correct_answer': 'A', 'answer_type': 'Person', 'topic': 'Politics'}


In [30]:
# Load the SimpleQA dataset through the project loader so the same question can
# be looked up in the short-answer source in the following cell.
sqa=load_and_format_dataset("SimpleQA")

Attempting to load SimpleQA (test split)...
Dataset loaded successfully.
Formatting 4326 questions...
Successfully formatted 4326 unique questions from SimpleQA.


In [31]:
# Locate the first SimpleQA entry whose text matches the fencing question and
# show its index together with the full record.
target_question = 'Which of the three Olympic fencing weapons was the last one to transition to using electrical equipment?'
first_match = next(
    ((idx, item) for idx, item in enumerate(sqa) if item['question'] == target_question),
    None,
)
if first_match is not None:
    i, q = first_match
    print(f"Found question at index {i}: {q}")

Found question at index 1500: {'id': 'sqa_test_789', 'question': 'Which of the three Olympic fencing weapons was the last one to transition to using electrical equipment?', 'correct_answer': 'Sabre', 'answer_type': 'Other', 'topic': 'Sports'}


In [None]:
import hashlib
import os
import json
def text_to_id(text):
    """Derive a stable question id from the question text.

    The id is the prefix ``sqa_test_`` followed by the hex SHA-256 digest of
    the UTF-8 encoded text, so identical text always maps to the same id.
    The text is hashed exactly as given (no whitespace or case normalisation).
    """
    digest = hashlib.sha256(text.encode('utf-8'))
    return f"sqa_test_{digest.hexdigest()}"

## Re-key the SimpleQA/SimpleMC game logs in place: every "id" in
## "phase1_questions" and "phase2_questions" becomes text_to_id(question), and
## every "question_id" in "results" becomes text_to_id(question_text).
log_dir = "delegate_game_logs"
for filename in os.listdir(log_dir):
    # Only game-data files whose name contains "_Simple" are touched.
    if "_Simple" not in filename or not filename.endswith("_game_data.json"):
        continue
    fname = os.path.join(log_dir, filename)
    with open(fname, 'r', encoding='utf-8') as f:
        game_data = json.load(f)

    for section in ('phase1_questions', 'phase2_questions'):
        for q in game_data[section]:
            q['id'] = text_to_id(q['question'])
    for q in game_data['results']:
        q['question_id'] = text_to_id(q['question_text'])

    # Overwrite the original file with the re-keyed data.
    with open(fname, 'w', encoding='utf-8') as f:
        json.dump(game_data, f, ensure_ascii=False, indent=2)

In [43]:
targ_dir = "compiled_results_sqa"
for filename in os.listdir(targ_dir):
    # Only the claude-3-5-sonnet phase-1 compilation is re-keyed by this cell.
    if "claude-3-5-sonnet-20241022_phase1_compiled.json" not in filename:
        continue
    if not filename.endswith(".json"):
        continue

    fname = os.path.join(targ_dir, filename)
    with open(fname, 'r', encoding='utf-8') as f:
        game_data = json.load(f)

    # Create new results dict keyed by the hash of each result's question text.
    new_results = {}
    colliding_old_ids = []
    for old_id, result_data in game_data['results'].items():
        new_id = text_to_id(result_data['question'])
        if new_id in new_results:
            # Two results share the same question text; assigning would
            # silently drop one of them.
            colliding_old_ids.append(old_id)
            continue
        new_results[new_id] = result_data

    if colliding_old_ids:
        # Leave the file untouched rather than writing a lossy version.
        print(f"Skipping {filename}: {len(colliding_old_ids)} result(s) share question text with an earlier result: {colliding_old_ids}")
        continue

    # Replace the results dict and write the file back in place.
    game_data['results'] = new_results

    with open(fname, 'w', encoding='utf-8') as f:
        json.dump(game_data, f, ensure_ascii=False, indent=2)

In [41]:
## check for duplicate ids
from collections import Counter


def find_duplicate_result_ids(path):
    """Return the keys that occur more than once in the file's "results" object.

    ``json.load`` normally builds a dict, which keeps only the last value for
    a repeated key, so counting ``dict.keys()`` can never report a duplicate.
    Parsing with ``object_pairs_hook=list`` keeps every (key, value) pair as
    written in the file, which makes repeated keys visible.

    Raises KeyError if the top-level object has no "results" key.
    """
    with open(path, 'r', encoding='utf-8') as f:
        top_level_pairs = json.load(f, object_pairs_hook=list)

    found_results = False
    duplicates = []
    for key, value in top_level_pairs:
        if key != 'results':
            continue
        found_results = True
        id_counts = Counter(result_id for result_id, _ in value)
        duplicates.extend(result_id for result_id, count in id_counts.items() if count > 1)
    if not found_results:
        raise KeyError('results')
    return duplicates


targ_dir = "compiled_results_sqa"
for filename in os.listdir(targ_dir):
    if not filename.endswith(".json"):
        continue
    fname = os.path.join(targ_dir, filename)
    duplicates = find_duplicate_result_ids(fname)
    if duplicates:
        print(f"Duplicate IDs found in {filename}: {duplicates}")
    else:
        print(f"No duplicate IDs found in {filename}")

No duplicate IDs found in gemini-2.0-flash-001_phase1_compiled.json
No duplicate IDs found in deepseek-chat_phase1_compiled.json
No duplicate IDs found in gemini-2.5-flash-preview-04-17_phase1_compiled.json
No duplicate IDs found in grok-3-latest_phase1_compiled.json
No duplicate IDs found in gpt-4o-2024-08-06_phase1_compiled.json
No duplicate IDs found in claude-sonnet-4-20250514_phase1_compiled.json
No duplicate IDs found in claude-3-5-sonnet-20241022_phase1_compiled.json


# Convert PopQA to PopMC

In [1]:
import json
import re
import os
from base_game_class import BaseGameClass

judge_model = "claude-sonnet-4-20250514"
popqa_file = "./data/PopQA_test.jsonl"
popmc_file = "./data/PopMC.jsonl"

# Set how many questions to process (None for all questions)
max_questions = 15000  # Change this to None or a number to process all or a specific amount
# Upper bound on judge calls per question, so a model that keeps returning
# malformed output cannot loop forever. Questions that exhaust it are skipped
# and will be retried on the next run because they are never written.
max_attempts = 5

# Initialize judge
judge = BaseGameClass(subject_id=None, subject_name=judge_model, is_human_player=False, log_dir=None)

sysprompt = ""
prompt = """I need your help turning a short-answer quiz into a multiple-choice quiz. I'm going to show you a question and a list of possible correct answers. 

First, select the most appropriate answer from the possible_answers list. This will be the correct answer.

Then, generate three distractors. Each distractor should be a plausible answer that is NOT the correct answer. Each should be the same "type" of answer as the correct answer (e.g., date, person name, number, occupation, etc), and follow the format of the correct answer.

Output format:
1. First line: The selected correct answer (exactly as it appears in possible_answers or a close variant)
2. Next three lines: The three distractors, one per line

Do not include any other text. Your entire response should be just the four answers, one per line.

Here is the question and possible answers:

Question: {question}

Possible Answers: {possible_answers}

Selected correct answer and distractors:
"""

# Load existing qids from popmc_file if it exists
existing_qids = set()
if os.path.exists(popmc_file):
    with open(popmc_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                try:
                    existing_data = json.loads(line)
                    existing_qids.add(existing_data['qid'])
                except (json.JSONDecodeError, KeyError, TypeError):
                    # Tolerate truncated/malformed lines (e.g. from an
                    # interrupted run) without hiding unrelated errors.
                    continue
    print(f"Found {len(existing_qids)} existing questions in {popmc_file}")

# Open file in append mode with context manager for safety
processed_count = 0
skipped_count = 0
checkpoint_interval = 10  # Print status every N questions

with open(popmc_file, 'a', encoding='utf-8') as fout:
    with open(popqa_file, 'r', encoding='utf-8') as f:
        for ctr, line in enumerate(f):
            # Check if we've reached max_questions
            if max_questions is not None and processed_count >= max_questions:
                print(f"Reached max_questions limit ({max_questions}). Stopping.")
                break

            if not line.strip():
                continue

            popqa_data = json.loads(line)

            # Extract needed fields
            qid = f"popmc_test_{popqa_data['id']}"

            # Check if this question already exists
            if qid in existing_qids:
                print(f"Skipping {qid} - already in {popmc_file}")
                skipped_count += 1
                continue

            question = popqa_data['question']
            possible_answers_str = popqa_data['possible_answers']
            prop = popqa_data['prop']
            s_pop = popqa_data['s_pop']
            o_pop = popqa_data['o_pop']

            # possible_answers is normally a JSON-encoded string, but accept an
            # already-decoded list as well.
            if isinstance(possible_answers_str, list):
                possible_answers = possible_answers_str
            else:
                try:
                    possible_answers = json.loads(possible_answers_str)
                except (TypeError, json.JSONDecodeError):
                    print(f"Warning: Could not parse possible_answers for {qid}. Skipping.")
                    continue

            # Format possible_answers as a readable string
            possible_answers_formatted = ", ".join([f'"{ans}"' for ans in possible_answers])

            print(f"Processing question {ctr+1}, ID: {qid}")

            try:
                q_text = prompt.format(question=question, possible_answers=possible_answers_formatted)
                correct_answer, distractors = None, None

                for _attempt in range(max_attempts):
                    resp, _, _ = judge._get_llm_answer(
                        options=None,
                        q_text=q_text,
                        message_history=[],
                        keep_appending=False,
                        setup_text=sysprompt,
                        MAX_TOKENS=None,
                        temp=1.0
                    )

                    # One answer per line; blank lines are dropped.
                    ans_list = re.split(r'\n+', resp) if resp else []
                    ans_list = [a.strip() for a in ans_list if a.strip()]

                    if len(ans_list) < 4:
                        print(f"Not enough answers returned for question {qid}. Got {len(ans_list)}. Retrying...")
                        continue

                    candidate_answer = ans_list[0]
                    candidate_distractors = ans_list[1:4]

                    # Check that we have exactly 4 unique answers (case-insensitive)
                    ans_set = {a.upper() for a in [candidate_answer] + candidate_distractors}
                    if len(ans_set) == 4:
                        correct_answer, distractors = candidate_answer, candidate_distractors
                        break
                    print(f"Invalid response format for question {qid}. Retrying...")

                if distractors is None:
                    print(f"Giving up on question {qid} after {max_attempts} attempts. Skipping.")
                    continue

                output_data = {
                    "qid": qid,
                    "question": question,
                    "correct_answer": correct_answer,
                    "distractors": distractors,
                    "prop": prop,
                    "s_pop": s_pop,
                    "o_pop": o_pop
                }

                # Write immediately and flush to ensure data is saved
                fout.write(json.dumps(output_data, ensure_ascii=False) + "\n")
                fout.flush()  # Force write to disk immediately

                # Add to existing_qids immediately to avoid duplicates if we resume
                existing_qids.add(qid)
                processed_count += 1

                # Periodic checkpoint message
                if processed_count % checkpoint_interval == 0:
                    print(f"Checkpoint: Processed {processed_count} questions so far...")

            except Exception as e:
                print(f"Error processing question {qid}: {e}")
                print(f"Continuing with next question. Progress saved: {processed_count} questions processed.")
                # Continue to next question instead of crashing
                continue

print(f"Conversion complete! Processed {processed_count} new questions, skipped {skipped_count} existing questions. Output written to {popmc_file}")


Provider: OpenRouter
Found 4927 existing questions in ./data/PopMC.jsonl
Skipping popmc_test_4222362 - already in ./data/PopMC.jsonl
Skipping popmc_test_4725190 - already in ./data/PopMC.jsonl
Skipping popmc_test_4382392 - already in ./data/PopMC.jsonl
Skipping popmc_test_4822110 - already in ./data/PopMC.jsonl
Skipping popmc_test_4011112 - already in ./data/PopMC.jsonl
Skipping popmc_test_1730929 - already in ./data/PopMC.jsonl
Skipping popmc_test_276787 - already in ./data/PopMC.jsonl
Skipping popmc_test_1758574 - already in ./data/PopMC.jsonl
Skipping popmc_test_6339290 - already in ./data/PopMC.jsonl
Skipping popmc_test_6250781 - already in ./data/PopMC.jsonl
Skipping popmc_test_1223902 - already in ./data/PopMC.jsonl
Skipping popmc_test_4607081 - already in ./data/PopMC.jsonl
Skipping popmc_test_596555 - already in ./data/PopMC.jsonl
Skipping popmc_test_3994636 - already in ./data/PopMC.jsonl
Skipping popmc_test_394743 - already in ./data/PopMC.jsonl
Skipping popmc_test_4650739 - 

# Convert unfiltered TriviaQA questions to TriviaMC

In [1]:
import json
import os
import re
import time
from base_game_class import BaseGameClass

# Model choice
# Configuration for the TriviaQA -> TriviaMC conversion:
#   judge_model        model name handed to BaseGameClass as subject_name
#   triviaqa_dev_file  source TriviaQA JSON file read by this cell
#   triviamc_file      JSONL output for accepted multiple-choice questions (appended to)
#   rejections_file    JSONL output for questions the judge flags (appended to)
#   max_questions      cap on processed + skipped entries per run
judge_model = "claude-opus-4-1-20250805"
triviaqa_dev_file = "./data/unfiltered-web-dev.json"
triviamc_file = "./data/TriviaMC.jsonl"
rejections_file = "./data/TriviaQA_fact_checked_rejections.jsonl"
max_questions = None  # None for full run

# Judge wrapper used for every LLM call in this cell.
trivia_judge = BaseGameClass(
    subject_id=None,
    subject_name=judge_model,
    is_human_player=False,
    log_dir=None,
)

# ---------- UNIFIED PROMPT ----------
# A single prompt asks the judge to fact-check the question and, if it is
# acceptable, to produce the multiple-choice version. The processing loop
# branches on whether the reply starts with "VALID" or "FLAGGED", so the
# output format described below must stay in sync with that parser.
# {question} and {possible_answers} are filled in with str.format.
trivia_sysprompt = ""
unified_prompt = """You are reviewing a trivia question for factual accuracy and converting it into multiple-choice format.

1. **First, check the factual accuracy** of the question and its candidate answers.
   - Be objective, not pedantic.
   - Do not correct or rewrite the question yourself.
   - If there is any potential factual inaccuracy, outdated information, or ambiguity, flag it and explain briefly why.
   - Otherwise, if the question seems factually valid, mark it as OK.

2. **If the question is OK**, generate a multiple-choice question using the original wording:
   - Use one correct answer (the best alias).
   - Create three plausible distractors of the same type.
   - Do not alter the original question text.

Output format:
If valid:
VALID
Correct: <correct answer>
Distractors:
- <distractor 1>
- <distractor 2>
- <distractor 3>

If there are factual concerns:
FLAGGED
Reason: <short explanation why this question may be inaccurate>

Question: {question}
Candidate aliases: {possible_answers}
"""


# ---------- HELPER FUNCTIONS ----------

def collect_aliases(entry):
    """Extract all possible aliases from a TriviaQA-style entry.

    Candidates are gathered from the list fields "Aliases" and
    "NormalizedAliases" and then from the string fields "Value",
    "MatchedWikiEntityName" and "NormalizedValue" of ``entry["Answer"]``.

    Returns a list of stripped, non-empty strings in first-seen order with
    exact duplicates removed. Non-string candidates (e.g. ``None`` inside an
    alias list) are ignored, and a missing or non-dict "Answer" yields ``[]``.
    """
    answer = entry.get("Answer") or {}
    if not isinstance(answer, dict):
        return []

    alias_candidates = []
    for key in ("Aliases", "NormalizedAliases"):
        values = answer.get(key) or []
        if isinstance(values, list):
            alias_candidates.extend(values)
    for key in ("Value", "MatchedWikiEntityName", "NormalizedValue"):
        alias_candidates.append(answer.get(key))

    cleaned = []
    seen = set()  # set membership keeps de-duplication linear
    for alias in alias_candidates:
        if not isinstance(alias, str):
            continue
        alias = alias.strip()
        if alias and alias not in seen:
            seen.add(alias)
            cleaned.append(alias)
    return cleaned


# ---------- LOAD EXISTING DATA ----------
def _iter_qid_and_question(path):
    """Yield (qid, lower-cased stripped question) for each JSON line in ``path``.

    Yields nothing when the file does not exist; blank lines and lines that
    are not valid JSON are skipped.
    """
    if not os.path.exists(path):
        return
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            try:
                parsed = json.loads(raw_line)
            except json.JSONDecodeError:
                continue
            yield parsed.get("qid"), parsed.get("question", "").strip().lower()


def _bump_next_id(qid, current_next):
    """Return the next free numeric id given a previously used ``qid``.

    Only qids of the form ``triviamc_<int>`` advance the counter; anything
    else leaves ``current_next`` unchanged.
    """
    if qid and qid.startswith("triviamc_"):
        try:
            return max(current_next, int(qid.split("_", 1)[1]) + 1)
        except ValueError:
            pass
    return current_next


existing_qids = set()
existing_questions = set()
rejected_questions = set()
next_trivia_id = 0

# Accepted questions: remember their ids and normalised text.
for qid, question_text in _iter_qid_and_question(triviamc_file):
    if qid:
        existing_qids.add(qid)
    next_trivia_id = _bump_next_id(qid, next_trivia_id)
    if question_text:
        existing_questions.add(question_text)

# Rejected questions: their ids also reserve numbers, but only the text is kept.
for qid, question_text in _iter_qid_and_question(rejections_file):
    next_trivia_id = _bump_next_id(qid, next_trivia_id)
    if question_text:
        rejected_questions.add(question_text)

print(f"Resuming from triviamc_{next_trivia_id}")
print(f"Found {len(existing_qids)} accepted and {len(existing_questions)} unique questions in {triviamc_file}")
print(f"Found {len(rejected_questions)} rejected questions in {rejections_file}")

# ---------- LOAD SOURCE DATA ----------
with open(triviaqa_dev_file, "r", encoding="utf-8") as f:
    data = json.load(f)

# The questions sit under "Data" when the file is a dict that has that key;
# otherwise the loaded object itself is used.
if isinstance(data, dict):
    dev_questions = data.get("Data", data)
else:
    dev_questions = data
print(f"Loaded {len(dev_questions)} raw trivia questions")

# ---------- PROCESS ----------
# `processed` counts accepted questions written to triviamc_file; `skipped`
# counts everything else (empty entries, duplicates, flagged questions,
# API failures, unparseable replies). Each entry is counted exactly once.
processed, skipped = 0, 0
checkpoint = 10

with open(triviamc_file, "a", encoding="utf-8") as trivia_out, \
     open(rejections_file, "a", encoding="utf-8") as reject_out:

    for entry in dev_questions:
        if max_questions and processed + skipped >= max_questions:
            print(f"Reached max_questions={max_questions}. Stopping.")
            break

        aliases = collect_aliases(entry)
        question_text = (entry.get("Question") or "").strip()
        if not aliases or not question_text:
            skipped += 1
            continue

        normalized_q = question_text.lower()
        if normalized_q in existing_questions:
            print(f"Skipping duplicate: {question_text[:80]}")
            skipped += 1
            continue

        if normalized_q in rejected_questions:
            print(f"Skipping previously rejected: {question_text[:80]}")
            skipped += 1
            continue

        # Always increment QID — even if rejected, we log it
        qid = f"triviamc_{next_trivia_id}"
        next_trivia_id += 1
        possible_answers = ", ".join([f'"{a}"' for a in aliases])

        print(f"\nProcessing {qid}: {question_text[:80]}...")

        # Retry logic for rate limit errors
        max_retries = 5
        retry_delay = 2  # Start with 2 seconds
        response = None

        for attempt in range(max_retries):
            try:
                response, _, _ = trivia_judge._get_llm_answer(
                    options=None,
                    q_text=unified_prompt.format(
                        question=question_text,
                        possible_answers=possible_answers,
                    ),
                    message_history=[],
                    keep_appending=False,
                    setup_text=trivia_sysprompt,
                    MAX_TOKENS=None,
                    temp=0.7,
                )
            except Exception as e:
                error_str = str(e)
                is_rate_limit = "429" in error_str or "rate" in error_str.lower() or "RateLimitError" in error_str
                if not is_rate_limit:
                    # There is no enclosing handler, so re-raising here would
                    # abort the whole run; report and give up on this question.
                    print(f"Error processing {qid}: {e}")
                    break
                if attempt < max_retries - 1:
                    wait_time = retry_delay * (2 ** attempt)  # Exponential backoff
                    print(f"⚠️ Rate limit hit for {qid}. Waiting {wait_time}s before retry {attempt + 1}/{max_retries}...")
                    time.sleep(wait_time)
                    continue
                # Counted as skipped once, by the `if not response` check below.
                print(f"⚠️ Rate limit error for {qid} after {max_retries} attempts. Skipping.")
                break

            if response:
                break  # Success, exit retry loop

        if not response:
            print(f"⚠️ No response for {qid}. Skipping.")
            skipped += 1
            continue

        try:
            resp = response.strip()

            if resp.startswith("FLAGGED"):
                # Extract reason and log to rejections file
                m = re.search(r"Reason:\s*(.*)", resp, re.I)
                reason = m.group(1).strip() if m else "Unspecified concern"
                print(f"⚠️ Flagged {qid}: {reason}")
                reject_out.write(json.dumps({
                    "qid": qid,
                    "question": question_text,
                    "aliases": aliases,
                    "reason": reason
                }, ensure_ascii=False) + "\n")
                reject_out.flush()
                # Remember the rejection so a repeat of this question later in
                # the same run is not sent to the judge again.
                rejected_questions.add(normalized_q)
                skipped += 1
                # Small delay to avoid hitting rate limits
                time.sleep(0.5)
                continue

            elif resp.startswith("VALID"):
                # Parse the answer and bullets; the question text is NOT rewritten.
                c_match = re.search(r"Correct:\s*(.+)", resp)

                # Only bullet lines count as distractors, and only those after
                # the "Distractors:" header when it is present. Matching any
                # "-" would pick up hyphens inside the correct answer
                # (e.g. "Correct: Jean-Paul Sartre").
                _, header, after_header = resp.partition("Distractors:")
                bullet_source = after_header if header else resp
                d_matches = [
                    d.strip()
                    for d in re.findall(r"^\s*-\s*(.+)$", bullet_source, re.M)
                    if d.strip()
                ]

                correct_answer = c_match.group(1).strip() if c_match else ""
                distractors = d_matches[:3]
                distinct = {correct_answer.upper()} | {d.upper() for d in distractors}

                # Require an answer plus three distractors, all different
                # (case-insensitively) from each other and from the answer.
                if not correct_answer or len(distractors) < 3 or len(distinct) != 4:
                    print(f"⚠️ Invalid format for {qid}. Skipping.")
                    skipped += 1
                    continue

                trivia_out.write(json.dumps({
                    "qid": qid,
                    "question": question_text,  # keep original
                    "correct_answer": correct_answer,
                    "distractors": distractors
                }, ensure_ascii=False) + "\n")

                trivia_out.flush()
                existing_qids.add(qid)
                existing_questions.add(question_text.lower())
                processed += 1

                if processed % checkpoint == 0:
                    print(f"Checkpoint: {processed} accepted questions so far.")

                # Small delay to avoid hitting rate limits
                time.sleep(0.5)

            else:
                print(f"⚠️ Unrecognized response format for {qid}. Skipping.")
                skipped += 1

        except Exception as e:
            print(f"Error processing {qid}: {e}")
            skipped += 1
            continue

print(f"\n✅ Conversion complete. Accepted {processed}, rejected {skipped}.")
print(f"Accepted written to {triviamc_file}")
print(f"Rejected logged to {rejections_file}")



Provider: OpenRouter
Resuming from triviamc_2883
Found 2403 accepted and 2403 unique questions in ./data/TriviaMC.jsonl
Found 302 rejected questions in ./data/TriviaQA_fact_checked_rejections.jsonl
Loaded 11313 raw trivia questions
Skipping duplicate: Who was the man behind The Chipmunks?
Skipping duplicate: What star sign is Jamie Lee Curtis?
Skipping duplicate: Which Lloyd Webber musical premiered in the US on 10th December 1993?
Skipping duplicate: Who was the next British Prime Minister after Arthur Balfour?
Skipping previously rejected: Who had a 70s No 1 hit with Kiss You All Over?
Skipping duplicate: What claimed the life of singer Kathleen Ferrier?
Skipping duplicate: Rita Coolidge sang the title song for which Bond film?
Skipping previously rejected: To the nearest million what is the population of Australia?
Skipping previously rejected: What was the last US state to reintroduce alcohol after prohibition?
Skipping duplicate: Which actress was voted Miss Greenwich Village in 1

KeyboardInterrupt: 

Provider that responded: Google


In [4]:
import json
f1 = "./delegate_game_logs/claude-3-5-sonnet-20241022_GPSA_50_450_team0.7_temp0.0_1749479584_game_data_evaluated.json"
f2 = "./delegate_game_logs/claude-3-5-sonnet-20241022_GPSA_50_450_nohistory_summary_team0.5_temp0.0_1749559243_game_data_evaluated.json"
with open(f1, 'r', encoding='utf-8') as f:
    game_data_hist = json.load(f)
with open(f2, 'r', encoding='utf-8') as f:
    game_data_nohist = json.load(f)

# Index the no-history trials by question id once instead of rescanning the
# whole list for every history trial. setdefault keeps the FIRST trial seen
# for an id, which is the one a first-match linear scan would return.
nohist_by_qid = {}
for trial_nohist in game_data_nohist['results']:
    nohist_by_qid.setdefault(trial_nohist['question_id'], trial_nohist)

diff_del_list = []     # pairs whose delegation choice differs between runs
diff_choice_list = []  # pairs both answered by "Self" but with different correctness
for trial in game_data_hist['results']:
    trial_nohist = nohist_by_qid.get(trial['question_id'])
    if trial_nohist is None:
        continue  # question not present in the no-history run
    if trial['delegation_choice'] != trial_nohist['delegation_choice']:
        diff_del_list.append((trial, trial_nohist))
    elif (trial['delegation_choice']=="Self" and trial['subject_correct'] != trial_nohist['subject_correct']):
        diff_choice_list.append((trial, trial_nohist))
print(f"Found {len(diff_del_list)} trials with delegation differences between history and no history versions.")
print(f"Found {len(diff_choice_list)} trials with choice differences between history and no history versions.")


FileNotFoundError: [Errno 2] No such file or directory: './delegate_game_logs/claude-3-5-sonnet-20241022_GPSA_50_450_team0.7_temp0.0_1749479584_game_data_evaluated.json'

In [51]:
# Display one example pair (history trial, no-history trial) where both runs
# answered themselves but disagreed on correctness. Requires the comparison
# cell above to have run successfully and produced at least 7 such pairs.
diff_choice_list[6]

({'subject_id': 'claude-3-5-sonnet-20241022_GPSA_50_450_team0.7_temp0.0',
  'phase': 2,
  'trial_in_phase': 193,
  'question_id': 'gpqa_train_recjgMJaMxz4ESDF2',
  'question_text': "Compounds that have the same molecular formula but are different in their structural arrangement are known as isomers. Isomers have two types, constitutional isomers and stereoisomers. Constitutional isomers have the same molecular formula but differ in their structures. In stereoisomers, molecules are connected in the same way, but their arrangements in space are different.\nWhich of the following organic moieties show optical isomerism?\n\n1. dimethyl 6,6'-dinitro-[1,1'-biphenyl]-2,2'-dicarboxylate\n2. methyl 2-hydroxypropanoate\n3. benzophenone\n4. dimethyl fumarate",
  'correct_answer': '1 and 2',
  'timestamp': 1749480176.218984,
  'subject_answer': '2 (methyl 2-hydroxypropanoate)\n\nThis compound has a chiral carbon with four different substituents: -H, -OH, -CH3, and -COOCH3.',
  'subject_correct': F

In [None]:
def contingency(delegate: np.ndarray, correct: np.ndarray):
    """
    Count the four delegate/correct combinations.

    delegate : array-like[N]   truthy -> model delegated
    correct  : array-like[N]   truthy -> model would be correct on its own
    returns  : TP, FN, FP, TN as numpy integer scalars
        TP = delegate & wrong, FN = keep & wrong,
        FP = delegate & right, TN = keep & right

    Both inputs are coerced to boolean arrays first. Without that, ``~`` on an
    integer 0/1 array is a bitwise NOT (~0 == -1, ~1 == -2), which produces
    wrong or even negative counts when ``delegate`` is not already boolean.
    """
    delegate = np.asarray(delegate, dtype=bool)
    correct = np.asarray(correct, dtype=bool)
    kept = ~delegate
    wrong = ~correct

    TP = np.sum(delegate & wrong)     # delegate & wrong
    FN = np.sum(kept & wrong)         # keep     & wrong
    FP = np.sum(delegate & correct)   # delegate & right
    TN = np.sum(kept & correct)       # keep     & right
    return TP, FN, FP, TN

def lift_mcc_stats(tp, fn, fp, tn, p0, n_boot=2000, seed=0):
    """
    Point estimates, bootstrap CIs and p-values for kept-accuracy lift and MCC.

    Parameters
    ----------
    tp, fn, fp, tn : int
        Contingency counts on Phase-2 items
            tp = delegate & wrong
            fn = keep & wrong
            fp = delegate & right
            tn = keep & right
    p0 : float
        Baseline accuracy to test against (global for RAW, hybrid value for HYBRID)
    n_boot : int
        Number of bootstrap resamples used for the percentile CIs.
    seed : int
        Seed for the bootstrap random generator (results are reproducible).

    Returns
    -------
    dict with keys lift, lift_ci, p_lift, mcc, mcc_ci, p_mcc where
        lift   = acc_kept - p0
        mcc    = Matthews correlation

    Degenerate inputs do not raise: with no kept items (fn + tn == 0) ``lift``
    and ``p_lift`` are NaN; with no items at all every statistic is NaN.
    Within a bootstrap resample an undefined accuracy or MCC is counted as 0.0.
    """
    rng = np.random.default_rng(seed)

    # Work with plain Python ints: the MCC denominator multiplies four
    # marginals, which can overflow fixed-width numpy integers for large N.
    tp, fn, fp, tn = (int(v) for v in (tp, fn, fp, tn))
    N = tp + fn + fp + tn

    # ---------- point estimates --------------------------------------------
    k         = fn + tn                       # kept items
    kept_acc  = tn / k if k else np.nan
    lift      = kept_acc - p0

    denom = math.sqrt((tp+fp)*(tp+fn)*(tn+fp)*(tn+fn))
    mcc   = (tp*tn - fp*fn) / denom if denom else np.nan

    # ---------- p-values ----------------------------------------------------
    # binomtest requires at least one trial, so guard the "nothing kept" case.
    p_lift = binomtest(tn, k, p0, alternative='two-sided').pvalue if k else np.nan
    # NOTE(review): McNemar's test compares the two off-diagonal cells of the
    # table passed in (fp vs fn); confirm this is the intended significance
    # test to report alongside MCC.
    p_mcc  = mcnemar([[tn, fp],
                      [fn, tp]], exact=True).pvalue if N else np.nan   # two-sided by default

    # ---------- bootstrap CIs ----------------------------------------------
    if N:
        counts   = np.array([tp, fn, fp, tn], int)
        probs    = counts / N

        lifts, mccs = [], []
        for _ in range(n_boot):
            # Resample N items from the observed cell proportions.
            sample = rng.choice(4, size=N, replace=True, p=probs)
            btp, bfn, bfp, btn = (int(c) for c in np.bincount(sample, minlength=4))

            bk     = bfn + btn
            b_acc  = btn / bk if bk else 0.0
            lifts.append(b_acc - p0)

            bden = math.sqrt((btp+bfp)*(btp+bfn)*(btn+bfp)*(btn+bfn))
            bmcc = (btp*btn - bfp*bfn) / bden if bden else 0.0
            mccs.append(bmcc)

        ci_lift = np.percentile(lifts, [2.5, 97.5])
        ci_mcc  = np.percentile(mccs,  [2.5, 97.5])
    else:
        # Nothing to resample from.
        ci_lift = (np.nan, np.nan)
        ci_mcc  = (np.nan, np.nan)

    return dict(
        lift       = lift,
        lift_ci    = tuple(ci_lift),
        p_lift     = p_lift,
        mcc        = mcc,
        mcc_ci     = tuple(ci_mcc),
        p_mcc      = p_mcc
    )

...

# Per-trial arrays built from df_model.
delegated = np.array(df_model['delegate_choice'], dtype=bool)
kept_mask = np.logical_not(delegated)                  # True where model answered itself
cap_corr = np.array(df_model['s_i_capability'], dtype=int)   # Baseline correctness from capabilities file

# Real in-game self correctness (only defined when kept): missing values and
# delegated trials are set to 0.
in_game_correct = df_model['team_correct'].fillna(0).astype(int)
team_corr = np.where(df_model['delegate_choice'] == 0, in_game_correct, 0).astype(int)

# Hybrid correctness label
#    – use real game correctness when the model kept
#    – fallback to baseline correctness when it delegated
true_label = np.where(kept_mask, team_corr, cap_corr)   # 1 = model would be correct

# ---- RAW: delegation vs. baseline correctness ----
TP, FN, FP, TN = contingency(delegated, cap_corr)
baseline_acc = cap_corr.mean()
raw_stats = lift_mcc_stats(TP, FN, FP, TN, baseline_acc)
log_output(f"Introspection score = {raw_stats['mcc']:.3f} [{raw_stats['mcc_ci'][0]:.3f}, {raw_stats['mcc_ci'][1]:.3f}], p={raw_stats['p_mcc']:.4g}")
delta_d, ci_low, ci_high, p_val = delegate_gap_stats(TP=TP, FN=FN, FP=FP, TN=TN)
log_output(f"Delegate Gap = {delta_d:.3f} [{ci_low:.3f}, {ci_high:.3f}, p={p_val:.4g}]")

# ---- HYBRID: delegation vs. the hybrid correctness label ----
TP, FN, FP, TN = contingency(delegated, true_label)
N = TP + FP + TN + FN
k = FN + TN                                  # kept items
acc_kept = TN / k                            # accuracy on kept items
acc_deleg = cap_corr[delegated].mean()       # baseline accuracy on delegated items
kept_frac = k / N
# Baseline for the hybrid lift: mix of the two accuracies weighted by how
# often the model kept vs. delegated.
p0_hyb = kept_frac * acc_kept + (1 - kept_frac) * acc_deleg
adj_stats = lift_mcc_stats(TP, FN, FP, TN, p0_hyb)

log_output(f"Adjusted introspection score = {adj_stats['mcc']:.3f} [{adj_stats['mcc_ci'][0]:.3f}, {adj_stats['mcc_ci'][1]:.3f}], p={adj_stats['p_mcc']:.4g}")
delta_d, ci_low, ci_high, p_val = delegate_gap_stats(TP=TP, FN=FN, FP=FP, TN=TN)
log_output(f"Adjusted delegate gap = {delta_d:.3f} [{ci_low:.3f}, {ci_high:.3f}, p={p_val:.4g}]")

log_output(f"Self-acc lift = {raw_stats['lift']:.3f} [{raw_stats['lift_ci'][0]:.3f}, {raw_stats['lift_ci'][1]:.3f}], p={raw_stats['p_lift']:.4g}")

log_output(f"Adjusted self-acc lift = {adj_stats['lift']:.3f} [{adj_stats['lift_ci'][0]:.3f}, {adj_stats['lift_ci'][1]:.3f}], p={adj_stats['p_lift']:.4g}")
