In [20]:
import matplotlib.pyplot as plt
import numpy as np
import json
import os

In [26]:
def load_data(path):
    """Read a JSON file and return the decoded Python object.

    Args:
        path: Filesystem path of the JSON file to read.

    Returns:
        The object produced by ``json.load`` for the file's content.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # JSON is exchanged as UTF-8; pin the encoding so the result does not
    # depend on the platform's locale default (e.g. cp1252 on Windows).
    with open(path, 'r', encoding='utf-8') as f_:
        return json.load(f_)

# Load the per-example results JSON from the results folder.
# NOTE(review): this path hard-codes the dataset ("graph") and model ("Qwen3-4B")
# that are also set as `dataset` / `model` in the config cell — keep them in sync.
results = load_data("results/graph_Qwen3-4B_results.json")

In [27]:
# Model identifier and dataset name; both are passed to plot_metrics, where they
# end up in the plot titles and output file names.
# Presumably the supported dataset values are "graph" and "mmlu_med" — TODO confirm.
model = "Qwen/Qwen3-4B"
dataset = "graph"

In [28]:
def _transition_stats(prev_scores, next_scores):
    """Summarise correct/incorrect transitions between two consecutive iterations.

    Only examples with a recorded (non-NaN) score in *both* iterations are
    counted, and "correct" means the score equals exactly 1.0. When no example
    qualifies, every count is 0 and both probabilities are ``None``.
    """
    valid = ~np.isnan(prev_scores) & ~np.isnan(next_scores)
    prev_ok = prev_scores[valid] == 1.0
    next_ok = next_scores[valid] == 1.0

    num_correct = int(np.sum(prev_ok))
    num_incorrect = int(np.sum(~prev_ok))
    correct_to_correct = int(np.sum(prev_ok & next_ok))
    incorrect_to_correct = int(np.sum(~prev_ok & next_ok))

    # Definition of conditional probability: P(B | A) = P(A and B) / P(A).
    # It is undefined (None) when the conditioning event never occurred.
    p_given_correct = (correct_to_correct / num_correct) if num_correct > 0 else None
    p_given_incorrect = (incorrect_to_correct / num_incorrect) if num_incorrect > 0 else None

    return {
        "valid_count": int(np.sum(valid)),
        "p_correct_next_given_correct": p_given_correct,
        "p_correct_next_given_incorrect": p_given_incorrect,
        "counts": {
            "correct_to_correct": correct_to_correct,
            "correct_to_incorrect": num_correct - correct_to_correct,
            "incorrect_to_correct": incorrect_to_correct,
            "incorrect_to_incorrect": num_incorrect - incorrect_to_correct,
            "num_correct": num_correct,
            "num_incorrect": num_incorrect,
        },
    }


def compute_iteration_metrics(results, max_iter=None):
    """Compute per-iteration accuracy, best-so-far accuracy and transition statistics.

    Args:
        results: Sequence of example records. Each record has an "iterations"
            list whose entries carry a 1-based "iteration" number and an
            optional "score". A missing or falsy score is recorded as 0.0.
        max_iter: Number of iterations (columns) to analyse. When ``None`` it is
            the largest ``len(record["iterations"])`` found, or 0 for no records.
            Entries whose iteration number lies outside ``1..max_iter`` are ignored.

    Returns:
        A dict with:
            "max_iter": number of iterations analysed.
            "num_examples": ``len(results)``.
            "iteration_accuracy": per iteration, the mean recorded score over the
                examples that have an entry for it, or ``None`` if none do.
            "best_so_far_accuracy": per iteration, the fraction of *all* examples
                that scored exactly 1.0 at this or any earlier iteration
                (``None`` when there are no examples).
            "conditional": ``{i: stats}`` for the transition from iteration ``i``
                to ``i + 1`` (1-based), with "valid_count",
                "p_correct_next_given_correct", "p_correct_next_given_incorrect"
                and a "counts" dict that always has the same six keys.
            "mat": float array of shape ``(num_examples, max_iter)`` holding the
                scores, NaN where an example has no entry for that iteration.
    """
    if max_iter is None:
        max_iter = max((len(r["iterations"]) for r in results), default=0)

    num_examples = len(results)

    # Score matrix: one row per example, one column per iteration; NaN = no entry.
    mat = np.full((num_examples, max_iter), np.nan)
    for ei, r in enumerate(results):
        for it in r["iterations"]:
            idx = it["iteration"] - 1  # iteration numbers are 1-based
            if 0 <= idx < max_iter:
                # A missing / None / falsy score counts as 0.0.
                mat[ei, idx] = it.get("score") or 0.0

    iteration_accuracy = []
    best_so_far_accuracy = []
    best_so_far = np.zeros((num_examples,), dtype=bool)
    for i in range(max_iter):
        col = mat[:, i]
        has_data = not np.all(np.isnan(col))
        iteration_accuracy.append(float(np.nanmean(col)) if has_data else None)

        # NaN == 1.0 is False, so examples without an entry never count as correct.
        best_so_far |= (col == 1.0)
        # Avoid np.mean on an empty array (RuntimeWarning + NaN) when there are no examples.
        best_so_far_accuracy.append(float(np.mean(best_so_far)) if num_examples > 0 else None)

    # P(correct_{i+1} | correct_i) and P(correct_{i+1} | incorrect_i) per transition.
    conditional = {
        i + 1: _transition_stats(mat[:, i], mat[:, i + 1])
        for i in range(max_iter - 1)
    }

    return {
        "max_iter": max_iter,
        "num_examples": num_examples,
        "iteration_accuracy": iteration_accuracy,
        "best_so_far_accuracy": best_so_far_accuracy,
        "conditional": conditional,
        "mat": mat
    }

In [29]:
def plot_metrics(metrics, output_dir, model_name, dataset_name):
    """Save the accuracy line plot and the conditional-probability bar plot as PNGs.

    Args:
        metrics: Dict providing "max_iter", "iteration_accuracy",
            "best_so_far_accuracy" and "conditional" (the structure returned by
            compute_iteration_metrics). ``None`` values are plotted as gaps.
        output_dir: Directory for the image files; created if it does not exist.
        model_name: Model label used in titles and file names. It becomes part
            of a file name, so it should not contain path separators.
        dataset_name: Dataset label used in titles and file names.

    Returns:
        Dict with the saved file paths under "lineplot" and "barplot".
    """
    os.makedirs(output_dir, exist_ok=True)
    max_iter = metrics["max_iter"]
    iters = list(range(1, max_iter + 1))

    # None marks "no data"; NaN makes matplotlib leave a gap at that point.
    y_iter = [np.nan if v is None else v for v in metrics["iteration_accuracy"]]
    y_best = [np.nan if v is None else v for v in metrics["best_so_far_accuracy"]]

    # --- Line plot: accuracy per iteration ---
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(iters, y_iter, marker='o', label='Iteration accuracy')
    ax.plot(iters, y_best, marker='o', linestyle='--', label='Best-so-far accuracy')
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Accuracy")
    ax.set_title(f"{dataset_name} — Accuracy per iteration (model={model_name})")
    ax.set_xticks(iters)
    ax.set_ylim(0.0, 1.0)
    ax.grid(True, linestyle=':', alpha=0.5)
    ax.legend()
    lineplot_path = os.path.join(output_dir, f"{dataset_name}_{model_name}_accuracy_iterations.png")
    fig.tight_layout()
    fig.savefig(lineplot_path)
    plt.close(fig)  # saved to disk only; release the figure

    # --- Bar plot: conditional probabilities per transition ---
    transitions = []
    p_corr_given_corr = []
    p_corr_given_inc = []
    valid_counts = []
    for t in sorted(metrics["conditional"].keys()):
        info = metrics["conditional"][t]
        p_cc = info["p_correct_next_given_correct"]
        p_ic = info["p_correct_next_given_incorrect"]
        transitions.append(f"{t}->{t+1}")
        p_corr_given_corr.append(np.nan if p_cc is None else p_cc)
        p_corr_given_inc.append(np.nan if p_ic is None else p_ic)
        valid_counts.append(info["valid_count"])

    x = np.arange(len(transitions))
    width = 0.33

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.bar(x - width/2, p_corr_given_corr, width, label='P(correct_{i+1} | correct_i)')
    ax.bar(x + width/2, p_corr_given_inc, width, label='P(correct_{i+1} | incorrect_i)')
    ax.set_xlabel("Transition (iteration -> next)")
    ax.set_ylabel("Probability")
    ax.set_title(f"{dataset_name} — Conditional improvement probabilities (model={model_name})")
    ax.set_xticks(x)
    ax.set_xticklabels(transitions)
    ax.set_ylim(0.0, 1.0)
    ax.grid(axis='y', linestyle=':', alpha=0.5)
    ax.legend()

    for xi, (pc, pi, vc) in enumerate(zip(p_corr_given_corr, p_corr_given_inc, valid_counts)):
        # An undefined probability (NaN) has no bar; skip its label rather than
        # placing a "nan" text at a non-finite position.
        if not np.isnan(pc):
            ax.text(xi - width/2, pc + 0.02, f"{pc:.2f}", ha='center', va='bottom', fontsize=9)
        if not np.isnan(pi):
            ax.text(xi + width/2, pi + 0.02, f"{pi:.2f}", ha='center', va='bottom', fontsize=9)
        ax.text(xi, 0.02, f"n={vc}", ha='center', va='bottom', fontsize=8, color='gray')

    barplot_path = os.path.join(output_dir, f"{dataset_name}_{model_name}_conditional_probs.png")
    fig.tight_layout()
    fig.savefig(barplot_path)
    plt.close(fig)

    return {"lineplot": lineplot_path, "barplot": barplot_path}

In [30]:
# Compute the per-iteration metrics over all loaded examples.
metrics = compute_iteration_metrics(results)
print(f"Computed metrics for max_iter = {metrics['max_iter']}, num_examples = {metrics['num_examples']}")

# Iteration accuracies. Iterations are numbered from 1 so these labels line up
# with the best-so-far and transition printouts below.
for iteration, acc in enumerate(metrics["iteration_accuracy"], start=1):
    print(f"Iteration {iteration} accuracy: {acc}")

# Best-so-far accuracy after each iteration.
for iteration, acc in enumerate(metrics["best_so_far_accuracy"], start=1):
    print(f"Best-so-far up to iteration {iteration}: {acc}")

# Conditional probabilities for each transition t -> t+1.
print("\n\nConditional probabilities P(correct_{i+1} | correct_i) and P(correct_{i+1} | incorrect_i):")
for t, info in metrics["conditional"].items():
    print(f"Transition {t}->{t+1}: valid_n={info['valid_count']}, "
        f"P(next|correct)={info['p_correct_next_given_correct']}, "
        f"P(next|incorrect)={info['p_correct_next_given_incorrect']} -- counts: {info['counts']}")

# '/' in the model id is replaced because the name becomes part of the output file names.
# NOTE(review): "xyz" looks like a placeholder output directory — confirm the intended location.
plot_paths = plot_metrics(metrics, "xyz", model.replace('/', '_'), dataset)

Computed metrics for max_iter = 4, num_examples = 100
Iteration 0 accuracy: 0.72
Iteration 1 accuracy: 0.49
Iteration 2 accuracy: 0.63
Iteration 3 accuracy: 0.55
Best-so-far up to iteration 1: 0.72
Best-so-far up to iteration 2: 0.78
Best-so-far up to iteration 3: 0.82
Best-so-far up to iteration 4: 0.82


Conditional probabilities P(correct_{i+1} | correct_i) and P(correct_{i+1} | incorrect_i):
Transition 1->2: valid_n=100, P(next|correct)=0.5972222222222222, P(next|incorrect)=0.21428571428571427 -- counts: {'correct_to_correct': 43, 'correct_to_incorrect': 29, 'incorrect_to_correct': 6, 'incorrect_to_incorrect': 22, 'num_correct': 72, 'num_incorrect': 28}
Transition 2->3: valid_n=100, P(next|correct)=0.8775510204081632, P(next|incorrect)=0.39215686274509803 -- counts: {'correct_to_correct': 43, 'correct_to_incorrect': 6, 'incorrect_to_correct': 20, 'incorrect_to_incorrect': 31, 'num_correct': 49, 'num_incorrect': 51}
Transition 3->4: valid_n=100, P(next|correct)=0.7301587301587301, P