In [6]:
from typing import Any, Dict, List, Tuple

import numpy as np

from src.config import MODELS_LIST
from src.llm_client import call_many_with_prefix, get_few_shot_prompt, format_answer
from src.math_dataset import MATHQuestion, eval_model_answers
from src.math_dataset import load_questions
from src.utils import choose_indices

In [24]:
async def multi_turn_supervision(
    weak_model: str,
    strong_model: str,
    train_dataset: List[MATHQuestion],
    test_dataset: List[MATHQuestion],
    indices: List[int],
    n: int = 1
) -> Tuple[Dict[str, Any], float]:
    """Run weak-to-strong supervision with ``n`` critique/revise rounds.

    Pipeline:
      1. The weak model answers the train questions selected by ``indices``;
         those (prompt, weak answer) pairs become the few-shot prompt.
      2. The strong model answers every test question with that few-shot
         prompt as prefix.
      3. For each of ``n`` rounds, the weak model critiques the strong
         model's current answers and the strong model revises them using
         that feedback.

    Args:
        weak_model: Model name used for few-shot labels and critiques.
        strong_model: Model name used for initial answers and revisions.
        train_dataset: Questions from which few-shot examples are drawn.
        test_dataset: Questions the strong model is evaluated on.
        indices: Positions in ``train_dataset`` to use as few-shot examples.
        n: Number of critique/revise rounds; with ``n <= 0`` the loop is
            skipped and the initial strong answers are evaluated.

    Returns:
        A ``(logs, acc)`` tuple. ``logs`` maps the keys
        ``"few_shot_init_examples"``, ``"initial_strong_answers"`` and, per
        round ``k``, ``"round_{k}_critiques"`` and
        ``"round_{k}_revised_strong_answers"`` to the corresponding lists.
        ``acc`` is the mean of ``eval_model_answers`` over the final answers.
    """
    logs: Dict[str, Any] = {}

    # Build the few-shot prompt from weak labels on the selected train examples.
    few_shot_questions = [train_dataset[i] for i in indices]
    weak_prompts = [q.get_prompt() for q in few_shot_questions]
    weak_responses = await call_many_with_prefix([], weak_prompts, model=weak_model)
    weak_answers = [format_answer(r.choices[0].message.content) for r in weak_responses]
    # Build the (prompt, answer) pairs once and reuse them for both the log
    # and the few-shot prompt, instead of re-deriving them twice.
    few_shot_pairs = list(zip(weak_prompts, weak_answers))
    logs["few_shot_init_examples"] = few_shot_pairs
    print("Weak model answered")
    # Pass a copy so the logged list is independent of whatever the helper does.
    few_shot_prompt = get_few_shot_prompt(list(few_shot_pairs))

    # 1) Strong model generates initial answers y0 for all test examples.
    test_prompts = [q.get_prompt() for q in test_dataset]
    strong_responses = await call_many_with_prefix(
        few_shot_prompt, test_prompts, model=strong_model
    )
    strong_answers = [format_answer(r.choices[0].message.content) for r in strong_responses]
    logs["initial_strong_answers"] = strong_answers

    for i in range(n):
        print(f"--- Iteration {i+1} ---")
        # 2) Weak model critiques each current answer (no few-shot prefix).
        critique_prompts = [
            f"{prompt}\n\nPrevious answer:\n{ans}\n\nPlease provide a short critique: point out incorrect steps or wrong final answer and give concise feedback."
            for prompt, ans in zip(test_prompts, strong_answers)
        ]
        print("Generating critiques with weak model...")
        weak_critique_responses = await call_many_with_prefix([], critique_prompts, model=weak_model)
        # NOTE(review): format_answer is also applied to free-form critiques;
        # confirm it does not strip the feedback down to a "final answer".
        critiques = [format_answer(r.choices[0].message.content) for r in weak_critique_responses]
        logs[f"round_{i+1}_critiques"] = critiques
        print("Weak model critiqued")

        # 3) Strong model revises using previous answer + reviewer feedback.
        revise_prompts = [
            (
                f"{prompt}\n\nPrevious answer:\n{prev}\n\nReviewer feedback:\n{fb}\n\n"
                "Please revise the previous answer, fix any mistakes, and provide a corrected final answer with reasoning."
            )
            for prompt, prev, fb in zip(test_prompts, strong_answers, critiques)
        ]
        print("Revising answers with strong model...")
        strong_responses = await call_many_with_prefix(few_shot_prompt, revise_prompts, model=strong_model)
        # Rebinding (not mutating) keeps earlier log entries pointing at the
        # previous round's answers.
        strong_answers = [format_answer(r.choices[0].message.content) for r in strong_responses]
        logs[f"round_{i+1}_revised_strong_answers"] = strong_answers
        print("Strong model revised")
        print("current accuracy:", float(np.mean(eval_model_answers(test_dataset, strong_answers))))

    acc = float(np.mean(eval_model_answers(test_dataset, strong_answers)))

    return logs, acc

In [25]:
# Load both MATH splits: few-shot examples come from train, evaluation uses test.
train_dataset, test_dataset = (load_questions(split) for split in ("train", "test"))

In [26]:
# Pick the train positions used as few-shot examples (fixed seed for repeatability).
indices = choose_indices(
    len(train_dataset),
    3,
    seed=42,
)

In [27]:
records = {}
for n in [1, 2, 3]:
    print(f"=== Multi-turn supervision with {n} rounds ===")
    logs, test_acc = await multi_turn_supervision(weak_model="gpt-4o-mini", strong_model="gpt-4.1-mini", train_dataset=train_dataset, test_dataset=test_dataset[:50], indices=indices, n=n)
    records[n] = {
        "logs": logs,
        "test_accuracy": test_acc
    }
    print(f"Final test accuracy after {n} rounds: {test_acc}\n")


=== Multi-turn supervision with 1 rounds ===
Weak model answered
--- Iteration 1 ---
Generating critiques with weak model...
Weak model critiqued
Revising answers with strong model...
Strong model revised
current accuracy: 0.92
Final test accuracy after 1 rounds: 0.92

=== Multi-turn supervision with 2 rounds ===
Weak model answered
--- Iteration 1 ---
Generating critiques with weak model...
Weak model critiqued
Revising answers with strong model...
Strong model revised
current accuracy: 0.92
--- Iteration 2 ---
Generating critiques with weak model...
Weak model critiqued
Revising answers with strong model...
Strong model revised
current accuracy: 0.9
Final test accuracy after 2 rounds: 0.9

=== Multi-turn supervision with 3 rounds ===
Weak model answered
--- Iteration 1 ---
Generating critiques with weak model...
Weak model critiqued
Revising answers with strong model...
Strong model revised
current accuracy: 0.88
--- Iteration 2 ---
Generating critiques with weak model...
Weak model 