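Heuristically count the overcommitment and undercommitment rates across the experiments. A run counts as overcommitted when its saved state has at most a threshold number of nodes, and as undercommitted when it contains a sufficiently long chain of single-child nodes ending at a leaf.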

In [None]:
import json
import os
from pathlib import Path

from redel.utils import read_jsonl

# Base directory holding all experiment outputs. It defaults to the original
# author's checkout; set the KANPAI_EXPERIMENTS environment variable to run the
# notebook against a different location without editing the code.
EXPERIMENTS = Path(os.environ.get("KANPAI_EXPERIMENTS", "/Users/andrew/Desktop/Code/kanpai/experiments"))


def is_overcommitted(fp, overcommitment_threshold):
    """Return whether the run saved at *fp* has few enough nodes to be overcommitted.

    The file is parsed as JSON and its ``"state"`` list is treated as the run's
    nodes. The run is flagged when that list holds at most
    ``overcommitment_threshold`` entries.
    """
    with open(fp) as state_file:
        node_list = json.load(state_file)["state"]
    node_count = len(node_list)
    return node_count <= overcommitment_threshold


def is_undercommitted(fp, undercommitment_threshold):
    """Return whether the run saved at *fp* contains a long single-child chain.

    The JSON file's ``"state"`` list is read as a tree. Each entry has an
    ``"id"``, a ``"depth"`` (the root is the entry with depth 0) and a
    ``"children"`` list of ids.

    Walking down from the root, consecutive nodes with at most one child form
    a run, and a node with two or more children resets it. The run is flagged
    when some leaf closes a run of at least ``undercommitment_threshold``
    nodes, counting the leaf itself.

    Only runs that terminate at a leaf are counted. A single-child chain that
    ends in a fan-out node does not trigger the flag.
    """
    with open(fp) as state_file:
        saved = json.load(state_file)

    by_id = {entry["id"]: entry for entry in saved["state"]}
    root = next(entry for entry in saved["state"] if entry["depth"] == 0)

    def walk(current, run_length):
        kid_ids = current["children"]
        fanout = len(kid_ids)
        if fanout == 0:
            # A leaf closes the run and counts as its final link.
            return run_length + 1 >= undercommitment_threshold
        # A single-child node extends the run; a fan-out starts a new one.
        next_run = run_length + 1 if fanout <= 1 else 0
        hit = False
        for kid_id in kid_ids:
            kid = by_id[kid_id]  # every child id is resolved, even after a hit
            if not hit:
                hit = walk(kid, next_run)
        return hit

    return walk(root, 0)

In [None]:
import json
from dataclasses import dataclass


@dataclass
class CommitmentResult:
    """Over-/undercommitment tallies collected for one system's experiment directory."""

    oc_count: int  # number of runs flagged as overcommitted
    uc_count: int  # number of runs flagged as undercommitted
    samples: int  # total number of runs inspected
    oc_ids: list[str]  # run directory names flagged as overcommitted
    uc_ids: list[str]  # run directory names flagged as undercommitted


def count_system(fp, overcommitment_threshold=2, undercommitment_threshold=3):
    """Count over- and undercommitted runs for one system's experiment directory.

    Each record in ``fp / "results.jsonl"`` names a run whose saved state is
    read from ``fp / <stem of record["log_dir"]> / "state.json"``. Every state
    is classified with ``is_overcommitted`` and ``is_undercommitted``, the
    resulting rates are printed, and the tallies are returned.

    Args:
        fp: ``Path`` to the system's experiment directory.
        overcommitment_threshold: Forwarded to ``is_overcommitted``.
        undercommitment_threshold: Forwarded to ``is_undercommitted``.

    Returns:
        A ``CommitmentResult``, or ``None`` if ``fp`` does not exist.
    """
    if not fp.exists():
        # Callers loop over optional systems; report the skip instead of
        # returning None without any trace.
        print(f"========== {fp} (missing, skipped) ==========")
        return None

    # One state.json per run, located via the stem of the run's log directory.
    state_paths = [
        fp / Path(result["log_dir"]).stem / "state.json"
        for result in read_jsonl(fp / "results.jsonl")
    ]

    oc_ids = []
    uc_ids = []
    for state_path in state_paths:
        run_id = state_path.parent.name
        if is_overcommitted(state_path, overcommitment_threshold):
            oc_ids.append(run_id)
        if is_undercommitted(state_path, undercommitment_threshold):
            uc_ids.append(run_id)

    n = len(state_paths)
    oc_count = len(oc_ids)
    uc_count = len(uc_ids)

    print(f"========== {fp} ==========")
    if n:
        print(f"Overcommitment rate: {oc_count / n} ({oc_count} / {n})")
        print(f"Undercommitment rate: {uc_count / n} ({uc_count} / {n})")
    else:
        # An empty results.jsonl previously raised ZeroDivisionError here.
        print("No samples found; commitment rates are undefined.")
    return CommitmentResult(oc_count=oc_count, uc_count=uc_count, samples=n, oc_ids=oc_ids, uc_ids=uc_ids)

In [None]:
# Commitment stats of the "full" system on each benchmark (FanOutQA, TravelPlanner, WebArena).
fo, tp, wa = (
    count_system(EXPERIMENTS / Path(relative_dir))
    for relative_dir in (
        "fanoutqa/dev/trial2/full",
        "travelplanner/validation/full",
        "webarena/test/full",
    )
)

In [None]:
# Pool the three benchmarks and report the combined rates.
per_benchmark = (fo, tp, wa)
oc_total = sum(result.oc_count for result in per_benchmark)
uc_total = sum(result.uc_count for result in per_benchmark)
n_total = sum(result.samples for result in per_benchmark)

print(f"Total overcommitment rate: {oc_total / n_total} ({oc_total} / {n_total})")
print(f"Total undercommitment rate: {uc_total / n_total} ({uc_total} / {n_total})")

In [None]:
# Systems to sweep; uncomment entries to include them.
systems = [
    "full",
    # "root-fc",
    # "baseline",
    # "small-leaf",
    "small-all",
    # "small-baseline",
    # "short-context",
    # "short-baseline",
]
benchmarks = ["fanoutqa/dev/trial2", "travelplanner/validation", "webarena/test"]

for system in systems:
    for benchmark in benchmarks:
        count_system(EXPERIMENTS / benchmark / system)

Compute scores conditioned on the over-/undercommitment results.

In [None]:
# foqa
benchmark = "fanoutqa/dev/trial2"
system = "full"

with open(EXPERIMENTS / benchmark / system / "score.json") as f:
    fo_scores = json.load(f)

commitment = count_system(EXPERIMENTS / benchmark / system)
# A run is "bad" if it was flagged as overcommitted OR undercommitted
# (previously uc_ids was unioned with itself, so oc runs were never excluded).
bad_ids = set(commitment.oc_ids) | set(commitment.uc_ids)

# Keep only questions whose run was neither over- nor undercommitted.
# NOTE(review): assumes a run's log-dir name equals its FOQA question_id — confirm.
good_scores = [s for s in fo_scores["raw"] if s["question_id"] not in bad_ids]

if good_scores:
    good_loose = sum(s["acc"] for s in good_scores) / len(good_scores)
    good_gpt = sum(s["gpt"] for s in good_scores) / len(good_scores)
else:
    # Every question was flagged; the filtered averages are undefined.
    good_loose = good_gpt = float("nan")

print("========== FOQA ==========")
print(f"Full Loose: {fo_scores['acc']['loose']}")
print(f"Full GPT: {fo_scores['gpt']}")
print(f"Filtered Loose: {good_loose}")
print(f"Filtered GPT: {good_gpt}")

# for system in ["full", "small-leaf"]:

In [None]:
# Size of the flagged-run id set built in the cell above.
len(bad_ids)

In [None]:
# Inspect the flagged run ids.
bad_ids

In [None]:
# foqa
benchmark = "fanoutqa/dev/trial2"
system = "full"

# Reload the per-question FOQA scores and collect the records whose "gpt" score is 0.
score_file = EXPERIMENTS / benchmark / system / "score.json"
fo_scores = json.loads(score_file.read_text())

fails = [record for record in fo_scores["raw"] if record["gpt"] == 0]
fail_ids = [record["question_id"] for record in fails]


# for system in ["full", "small-leaf"]:

In [None]:
# Inspect the raw FOQA score records whose "gpt" score is 0.
fails