# **C**hain of **A**gents with Chain of **V**erific**A**tion (CAVA)


## 1) Setup & Dependencies

In [1]:
!pip install -qU langchain langchain-core langchain-text-splitters langchain-community langchain-google-genai

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/475.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m475.9/475.9 kB[0m [31m23.0 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/2.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m93.5 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.6/63.6 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.0 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m59.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m

In [2]:
# Drive (for logs)
# Mounts Google Drive at /content/drive; LOG_PATH in the config cell points
# inside this mount, so this cell has to run before anything is logged.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Core deps
import os, json, random, re, string
from typing import List, Dict, Callable, TypedDict

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain_core.runnables import RunnableLambda
from langchain_text_splitters import RecursiveCharacterTextSplitter

from datasets import load_dataset
from tqdm import tqdm

# Disable grad for torch
# (the cells in this notebook only run generation, never a backward pass)
torch.set_grad_enabled(False)

# Gemini (optional)
import google.generativeai as genai
from google.colab import userdata

# Only configure Gemini when the Colab secret is actually available, so the
# local-model path keeps working on a runtime without a GOOGLE_API_KEY.
try:
    _google_api_key = userdata.get('GOOGLE_API_KEY')
except (userdata.SecretNotFoundError, userdata.NotebookAccessError) as err:
    _google_api_key = None
    print(f"Gemini disabled: could not read GOOGLE_API_KEY ({type(err).__name__}).")

if _google_api_key:
    os.environ["GOOGLE_API_KEY"] = _google_api_key
    genai.configure(api_key=_google_api_key)

## 2) Experiment Config

In [4]:
# ---- Generation & Chunking ----
MAX_NEW_TOKENS = 128  # generation budget passed to the LLM loader as max_new_tokens
CHUNK_SIZE = 2000  # chunk_size for the text splitter; chunk_overlap is derived as 10% of it

# ---- Models ----
LOCAL_MODEL_NAME = "Qwen/Qwen2.5-3B-Instruct"  # model id given to load_local_llm
GEMINI_MODEL_NAME = "gemini-2.5-flash-lite"  # only used if the Gemini loader line is un-commented

# ---- CoVe ----
VERIFICATION_MODE = "every_k" # "none" | "every" | "every_k"
VERIFICATION_K = 3  # with "every_k": verify after every K-th worker (workers 3, 6, ... counting from 1)

# ---- Logging ----
# JSONL file on the mounted Drive. evaluate() opens it in append mode, so
# re-running the evaluation adds records to the same file.
LOG_PATH = "/content/drive/MyDrive/GenAI/project/logs/hotpotqa_qwen3B_every3_test.jsonl"
FLUSH_EVERY = 5  # flush + fsync the log after this many samples
INCLUDE_CONTEXT_IN_LOG = False  # when True, each log record also carries the merged context texts

# ---- Dataset selection ----
# data[SUBSET_START:SUBSET_END] is the slice that actually gets evaluated.
SUBSET_START = 0
SUBSET_END = 10
NUM_SAMPLES_TO_LOAD = SUBSET_END - SUBSET_START  # only referenced by the commented-out load_hotpotqa option
SEED = 130  # seeds the sampling in both dataset loaders

# balanced sampling (easy/medium)
LEVELS = ("easy", "medium")
NUM_SAMPLES_TO_LOAD_PER_LEVEL = 5  # samples drawn per level, i.e. 5 * len(LEVELS) in total

## 3) LLM Backends

In [5]:
# Function to create LLM
def load_local_llm(model_id, max_new_tokens=MAX_NEW_TOKENS):
    """Load a Hugging Face causal LM and expose greedy generation as a Runnable.

    Args:
        model_id: Model identifier passed to ``from_pretrained``.
        max_new_tokens: Upper bound on the number of tokens generated per call.

    Returns:
        A ``RunnableLambda`` whose ``invoke(prompt)`` returns only the newly
        generated text (prompt tokens removed, special tokens skipped,
        surrounding whitespace stripped).

    NOTE(review): the prompt is tokenized as-is; no chat template is applied
    even though the configured model is an instruct model. Confirm that this
    is intended.
    """
    hf_tokenizer = AutoTokenizer.from_pretrained(model_id)
    hf_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        dtype=torch.float16,
    )
    hf_model.eval()

    def _complete(prompt: str) -> str:
        encoded = hf_tokenizer(prompt, return_tensors="pt").to(hf_model.device)
        prompt_len = encoded["input_ids"].shape[1]
        with torch.inference_mode():
            sequences = hf_model.generate(
                **encoded,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy decoding
                use_cache=True,
            )
        # generate() returns prompt + continuation; keep only the continuation.
        continuation = sequences[0, prompt_len:]
        return hf_tokenizer.decode(continuation, skip_special_tokens=True).strip()

    # Wrapped as a Runnable so callers can use llm.invoke(prompt).
    return RunnableLambda(_complete)


# Use Gemini, we use the same runnable to make it able to call llm.invoke()
def load_google_llm(model_name: str, max_new_tokens: int = 256):
    """Wrap a Gemini model as a Runnable with the same ``invoke(prompt) -> str`` shape as the local LLM.

    Args:
        model_name: Gemini model name passed to ``genai.GenerativeModel``.
        max_new_tokens: Sent to the API as ``max_output_tokens``.

    Returns:
        A ``RunnableLambda`` whose ``invoke(prompt)`` returns the stripped
        response text, or ``""`` when the response carries no text.
    """
    model = genai.GenerativeModel(model_name)
    generation_config = {
        "max_output_tokens": max_new_tokens,
        "temperature": 0.0,   # deterministic for QA
        "top_p": 1.0,
    }

    def _generate(prompt: str) -> str:
        response = model.generate_content(prompt, generation_config=generation_config)
        try:
            # response.text is the concatenated text of all parts ...
            text = response.text
        except ValueError:
            # ... but the accessor raises when there are no text parts at all
            # (e.g. the candidate was blocked). Treat that as an empty answer
            # so a single refused prompt does not abort a whole evaluation run.
            return ""
        return (text or "").strip()

    return RunnableLambda(_generate)


# Loading model
# Exactly one backend is active: swap which of the two lines below is
# commented out to run the pipeline on Gemini instead of the local model.
llm_strong = load_local_llm(LOCAL_MODEL_NAME, max_new_tokens=MAX_NEW_TOKENS)
# llm_strong = load_google_llm(model_name=GEMINI_MODEL_NAME,max_new_tokens=MAX_NEW_TOKENS)


# Roles
# Every pipeline role currently points at the same backend object; rebinding
# one of these names is enough to give that role a different model.
worker = llm_strong
manager = llm_strong
verifier = llm_strong
extractor = llm_strong

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/661 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/3.97G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

## 4) Text Chunking

In [6]:
# Shared splitter used by run_cava() to cut the merged context into chunks.
# chunk_overlap is 10% of CHUNK_SIZE; the separators are listed from the
# coarsest (blank line) to the finest (single space).
splitter = RecursiveCharacterTextSplitter(
    chunk_size=CHUNK_SIZE,
    chunk_overlap=int(CHUNK_SIZE*0.1),
    separators=["\n\n", "\n", ". ", " "]
)

## 5) Prompts (CoA + CoVe + Answer Extraction)

In [7]:
# ---- CoA ----
def WORKER_PROMPT(i, query, chunk, prev):
    """Build the prompt for worker ``i``.

    Args:
        i: Worker index; it is also used to label the chunk in the prompt.
        query: The question being answered.
        chunk: Source text assigned to this worker.
        prev: Summary handed over by the previous worker.

    Returns:
        The prompt string asking for a combined summary.
    """
    return f"""
You are Worker {i} in a chain solving a long-context question answering task.

Use ONLY:
- the current source text (chunk)
- the previous worker summary

Task:
- Write a new summary that combines:
  (a) all information from the previous summary that is relevant to the query, and
  (b) any new relevant information in the current chunk.
- If the current chunk adds no new relevant information, simply repeat the previous summary unchanged.

Constraints:
- Maximum length: about 300 tokens.
- Output only the new summary, no commentary about your process.

Query:
{query}

Current source text (CHUNK {i}):
{chunk}

Previous worker summary:
{prev}

Now output the combined summary:
"""


def MANAGER_PROMPT(query, final_worker_json):
    """Build the manager prompt that turns the last worker summary into an answer.

    Args:
        query: The question being answered.
        final_worker_json: Evidence summary inserted under "Summary of evidence".

    Returns:
        The prompt string; it instructs the model to end with a
        ``Final answer: <answer>`` line.
    """
    return f"""
You are the Manager in a HotpotQA question answering system.

Task:
- Read the summary of evidence.
- Reason briefly about the answer.
- Then output the final answer as a short span, try to find the closest answer.

Output format (very important):
1. First, write a short reasoning paragraph if needed.
2. On the LAST line of your response, write exactly:

   Final answer: <answer>

Rules for <answer>:
- Use the shortest possible span (a name, location, date, number, or "yes"/"no").
- For yes/no questions, answer exactly "yes" or "no".
- Do NOT add any text after <answer> on that line.
- Do NOT write anything after the "Final answer: ..." line (no notes, no extra sentences).

Query:
{query}

Summary of evidence:
{final_worker_json}
"""


# ---- CoVe ----
def PLAN_VERIFICATIONS_PROMPT(query, chunk, baseline_summary):
    """Build the prompt that asks for verification questions about a summary.

    Args:
        query: The original question.
        chunk: Source text the summary was written from.
        baseline_summary: The summary to be checked.

    Returns:
        The prompt string requesting a numbered list of verification questions.
    """
    return f"""
You are verifying a summary used in a long-context QA pipeline.

Original Query: {query}

Source chunk: {chunk}

Baseline summary: {baseline_summary}

Task:
Generate a small list of concrete verification questions (2–4) that help check:
- factual correctness
- coverage of key information relevant to the query
- absence of unsupported claims
Return the verification questions as a numbered list.
"""


def EXEC_VERIFICATIONS_PROMPT(query, chunk, qa_block):
    """Build the prompt that asks the verifier to answer all verification questions at once.

    Args:
        query: The original question.
        chunk: Source text to answer from.
        qa_block: The verification questions, already formatted as one text block.

    Returns:
        The prompt string requesting a single numbered list of answers.
    """
    return f"""
You are answering verification questions about a summary for a long-context QA pipeline.

Original Query: {query}

Source chunk: {chunk}

Here is a list of verification questions:
{qa_block}

For each question, answer concisely.
Formatting rules (very important):
- Return your answers as a **single numbered list**.
- Use exactly one line per answer.
- Do NOT repeat the list.
- Do NOT restate the questions.
- The format must be:

  1. <answer to Q1>
  2. <answer to Q2>
  3. <answer to Q3>
  ...
"""


def GEN_FINAL_RESPONSE_PROMPT(query, chunk, baseline_summary, questions, answers):
    """Build the prompt that asks for a revised summary given the verification Q&A.

    Args:
        query: The original question.
        chunk: Source text the summary must stay faithful to.
        baseline_summary: The summary before verification.
        questions: Verification questions.
        answers: Answers paired positionally with ``questions``; pairing stops
            at the shorter of the two sequences.

    Returns:
        The prompt string requesting only the revised summary.
    """
    # Built outside the template: a nested f-string containing a backslash
    # inside a replacement field is a SyntaxError before Python 3.12.
    qa_text = "\n".join(f"Q: {q}\nA: {a}" for q, a in zip(questions, answers))
    return f"""
You are revising a summary for a long-context QA pipeline.

Original Query: {query}

Source chunk: {chunk}

Baseline summary: {baseline_summary}

Verification Q&A:
{qa_text}

Task:
Write a revised summary that:
- corrects any factual errors in the baseline summary
- adds missing key information supported by the source chunk
- removes unsupported or speculative claims
- remains concise and focused on information relevant to the question

Return ONLY the revised summary.
"""

# ---- Final Answer Extraction ----
def EXTRACT_ANSWER_PROMPT(query, manager_output):
    """Build the prompt that asks the extractor for the bare answer span.

    Args:
        query: The original question.
        manager_output: Full text produced by the manager.

    Returns:
        The prompt string requesting only the answer string.
    """
    return f"""
You are post-processing the output of a QA system on the HotpotQA dataset.

Your task: extract the **final answer string** that should be evaluated against the gold answer.

Constraints (very important):
- Return **only** the minimal answer span.
- Do **not** include explanations, reasoning, or extra words.
- Do **not** include phrases like "The answer is", "It is", "Final answer", etc.
- Do **not** add punctuation at the beginning or end unless it is part of the entity (e.g., "U.S.").
- Do **not** output multiple sentences.
- If the question is yes/no, answer with exactly **yes** or **no**.
- If the model’s answer is clearly wrong or missing, output exactly **no answer**.

Output format:
- Your entire response must be **only** the answer string, with no quotation marks and no additional text.

Query:
{query}

Model's answer:
{manager_output}

Now output the answer string only:
"""

## 6) CAVA Pipeline (CoA workers + optional CoVe verification + manager)

In [8]:
class VerificationTrace(TypedDict):
    """Record of one verification pass over a single worker summary (built by run_cove)."""
    worker_idx: int  # index of the worker/chunk whose summary was verified
    baseline_summary: str  # summary as it was before verification
    verification_questions: List[str]  # questions parsed from the planning response
    verification_answers: List[str]  # answers to the verification questions
    verified_summary: str  # revised summary that replaces the baseline in the chain


class CoAState(TypedDict):
    """Mutable state dict threaded through worker, verification and manager steps."""
    query: str  # question being answered
    chunks: List[str]  # context split into chunks, one per worker
    i: int  # index of the next worker to run; worker_node increments it
    worker_outputs: List[str]  # one summary per finished worker (may be replaced by its verified version)
    verbose: bool  # enables debug printing in every node
    manager_output: str  # raw manager response, filled in by manager_node

    verification_mode: str # "none" | "every" | "every_k"
    verification_k: int  # period used when verification_mode == "every_k"
    store_verification_traces: bool  # when true, verification_node appends each trace below
    verification_traces: List[VerificationTrace]

In [9]:
def worker_node(state: CoAState):
    """Run the worker for chunk ``state["i"]`` and advance the chain by one step.

    The worker is prompted with the query, its own chunk and the summary of
    the previous worker (a fixed placeholder for the first worker). Its
    output is appended to ``state["worker_outputs"]`` and ``state["i"]`` is
    incremented. The state is mutated in place.

    Returns:
        The same ``state`` object.
    """
    idx = state["i"]
    source_chunk = state["chunks"][idx]
    debug = state["verbose"]

    # The first worker has nothing to build on; later workers chain off the
    # summary written by the worker right before them.
    previous_summary = "No Previous summaries" if idx == 0 else state["worker_outputs"][idx - 1]
    prompt = WORKER_PROMPT(idx, state["query"], source_chunk, previous_summary)

    if debug:
        print(f"Worker {idx} with Prompt: \n######{prompt}\n#######\n")
        print("worker invoke")
    summary = worker.invoke(prompt)
    if debug:
        print("worker invoke -- done")

    # Record the new output and move on to the next chunk.
    state["worker_outputs"].append(summary)
    state["i"] = idx + 1
    if debug:
        print(f"Outputs: {summary}\n------------------\n\n")

    return state


def manager_node(state:CoAState):
    """Ask the manager for the final answer based on the last worker summary.

    The raw manager response is stored in ``state["manager_output"]``. The
    state is mutated in place.

    Returns:
        The same ``state`` object.
    """
    summaries = state["worker_outputs"]
    # An empty context yields zero chunks, so no worker ever ran. Hand the
    # manager an explicit placeholder instead of failing with an IndexError.
    last_worker_output = summaries[-1] if summaries else "No evidence available."

    prompt = MANAGER_PROMPT(state["query"], last_worker_output)
    if state["verbose"]:
        print(f"Manager with Prompt: \n######{prompt}\n#######\n")
    final_answer = manager.invoke(prompt)
    # store the manager response as the final output
    state["manager_output"] = final_answer
    if state["verbose"]:
        print(f"Manager Final Output: \n#############\n{final_answer}")

    return state

In [10]:
def parse_numbered_answers(exec_text: str, num_questions: int) -> List[str]:
    """Extract the answers from a numbered list in an LLM response.

    Example: ``"1. Yes\\n2) No\\n3 - Maybe"`` gives ``["Yes", "No", "Maybe"]``.

    Rules:
    - Only lines whose first non-space character is a digit are considered;
      preamble text, blank lines and later prose are all skipped.
    - A leading marker such as ``1.``, ``1)`` or ``1 -`` is removed.
    - If a line contains a second inline marker ("Yes 1. Yes ..."), only the
      text before it is kept.
    - Parsing stops as soon as ``num_questions`` answers were collected.
    - If nothing could be parsed, the whole stripped text is returned as a
      single answer.

    NOTE(review): the inline-marker rule also cuts answers that contain a
    number followed by ``.``, ``)`` or ``-`` (e.g. "Born in 1965. Died ...").
    """
    leading_marker = re.compile(r"^\d+\s*[\.\)\-]\s*")
    inline_marker = re.compile(r"\s+\d+\s*[\.\)\-]\s*")

    collected: List[str] = []
    for raw_line in exec_text.split("\n"):
        candidate = raw_line.strip()
        # One check covers blank lines, the preamble and any non-list prose.
        if not candidate or not candidate[0].isdigit():
            continue

        body = leading_marker.sub("", candidate).strip()
        body = inline_marker.split(body)[0].strip()
        if not body:
            continue

        collected.append(body)
        if len(collected) >= num_questions:
            break

    return collected if collected else [exec_text.strip()]


def run_cove(query: str, chunk: str, baseline_summary: str, worker_idx: int, verbose: bool = False) -> VerificationTrace:
    """Run one Chain-of-Verification pass over a worker summary.

    Steps:
      1. The baseline is the summary already produced by the worker.
      2. Plan: ask the verifier for verification questions.
      3. Execute: answer all questions in a single joint verifier call.
      4. Revise: ask the verifier for a corrected summary given the Q&A.

    Args:
        query: The original question.
        chunk: Source chunk the summary was written from.
        baseline_summary: Worker summary to verify.
        worker_idx: Index of the worker (used for the trace and debug output).
        verbose: Print prompts and intermediate results.

    Returns:
        A ``VerificationTrace``. ``verification_answers`` holds the answers
        that were actually fed to the revision prompt: the parsed per-question
        answers, or the raw answer text as a single item when parsing did not
        yield exactly one answer per question. If no question could be parsed
        from the plan, the baseline summary is returned unchanged.
    """
    # 1. Baseline response = baseline_summary (already produced by worker)
    # 2. Plan verification questions
    plan_prompt = PLAN_VERIFICATIONS_PROMPT(query, chunk, baseline_summary)

    if verbose:
        print(f"[CoVe][Worker {worker_idx}] Plan prompt:\n{plan_prompt}\n")

    plan_resp = verifier.invoke(plan_prompt)
    # Depending on the LLM wrapper, invoke() may return a plain string or a
    # message object carrying the text in `.content`.
    plan_text = str(getattr(plan_resp, "content", plan_resp))

    # Keep every line that contains a question mark and drop only a leading
    # bullet / list marker ("1.", "2)", "3 -", "- "). Digits that are part of
    # the question itself (e.g. a leading year) are preserved.
    list_marker = re.compile(r"^[\s\-\*•]*(?:\d+\s*[\.\)\-:]\s+)?")
    questions = []
    for line in plan_text.split("\n"):
        if "?" not in line:
            continue
        question = list_marker.sub("", line).strip()
        if question:
            questions.append(question)

    if verbose:
        print(f"[CoVe][Worker {worker_idx}] Verification Questions:\n{questions}\n")

    if not questions:
        # Nothing to verify against: skip the two remaining verifier calls
        # rather than revising the summary from an empty Q&A.
        return {
            "worker_idx": worker_idx,
            "baseline_summary": baseline_summary,
            "verification_questions": [],
            "verification_answers": [],
            "verified_summary": baseline_summary,
        }

    # 3. Execute verifications (joint: all questions answered in one call)
    qa_block = "\n".join(f"{n}. {q}" for n, q in enumerate(questions, start=1))
    exec_prompt = EXEC_VERIFICATIONS_PROMPT(query, chunk, qa_block)
    exec_resp = verifier.invoke(exec_prompt)
    exec_text = str(getattr(exec_resp, "content", exec_resp)).strip()

    if verbose:
        print(f"verification q answer: {exec_text}")

    # answer parsing
    parsed_answers = parse_numbered_answers(exec_text, num_questions=len(questions))

    if len(parsed_answers) != len(questions):
        # Questions and answers cannot be paired one-to-one; fall back to a
        # single block containing all questions and the raw answer text.
        questions_for_gen = ["\n".join(questions)]
        answers_for_gen = [exec_text]
    else:
        questions_for_gen = questions
        answers_for_gen = parsed_answers

    if verbose:
        print(f"[CoVe][Worker {worker_idx}] Answers:\n{answers_for_gen}\n")

    # 4. Generate final verified summary
    final_prompt = GEN_FINAL_RESPONSE_PROMPT(query, chunk, baseline_summary, questions_for_gen, answers_for_gen)
    if verbose:
        print(f"final_prompt: {final_prompt}")
    final_resp = verifier.invoke(final_prompt)
    final_summary = str(getattr(final_resp, "content", final_resp)).strip()

    if verbose:
        print(f"[CoVe][Worker {worker_idx}] Final verified summary:\n{final_summary}\n")

    trace: VerificationTrace = {
        "worker_idx": worker_idx,
        "baseline_summary": baseline_summary,
        "verification_questions": questions,
        "verification_answers": answers_for_gen,
        "verified_summary": final_summary,
    }
    return trace


def verification_node(state: CoAState, worker_idx: int):
    """Verify the summary of one worker and write the result back into the state.

    Runs ``run_cove`` on ``state["worker_outputs"][worker_idx]``, replaces
    that entry with the verified summary and, when
    ``store_verification_traces`` is set, appends the trace to
    ``state["verification_traces"]``. The state is mutated in place.

    Returns:
        The same ``state`` object.
    """
    question = state["query"]
    source_chunk = state["chunks"][worker_idx]
    summary_obj = state["worker_outputs"][worker_idx]

    trace = run_cove(
        query=question,
        chunk=source_chunk,
        # the stored output may be a plain string or a message-like object
        baseline_summary=str(getattr(summary_obj, "content", summary_obj)),
        worker_idx=worker_idx,
        verbose=state["verbose"],
    )

    # The verified text takes the place of the worker's own summary, so the
    # next worker (or the manager) builds on the verified version.
    state["worker_outputs"][worker_idx] = trace["verified_summary"]

    # A missing flag counts as "do not store".
    keep_trace = state.get("store_verification_traces", False)
    if keep_trace:
        state["verification_traces"].append(trace)

    return state


def maybe_run_verification(state: CoAState) -> CoAState:
    """Verify the most recent worker summary if the verification mode asks for it.

    Param:
    state (as returned by worker_node(), i.e. state["i"] already points one
    past the worker that just ran)

    Modes:
    - "none": never verify
    - "every": verify after every worker
    - "every_k": verify after every k-th worker, counting workers from 1
    - anything else: no verification

    Return:
    updated state
    """
    mode = state["verification_mode"]
    k = state["verification_k"]
    latest_idx = state["i"] - 1  # worker_node() has already advanced the counter

    if mode == "every":
        should_verify = True
    elif mode == "every_k":
        should_verify = (latest_idx + 1) % k == 0
    else:
        should_verify = False

    return verification_node(state, latest_idx) if should_verify else state

In [11]:
def _extract_final_answer(manager_text):
    """Return the answer that follows the last "Final answer:" marker.

    The marker is matched case-insensitively and tolerates missing or extra
    whitespace around the colon. Only the first line after the marker is kept
    and the original casing is preserved. When no marker is present the text
    is returned unchanged.
    """
    segments = re.split(r"final answer\s*:", manager_text, flags=re.IGNORECASE)
    if len(segments) == 1:
        return manager_text
    tail = segments[-1].strip()
    return tail.splitlines()[0].strip() if tail else ""


def run_cava(query, context, verbose=True, verification_mode="none", verification_k=1, store_verification_traces=True, postprocess=True):
    """Answer ``query`` over a long ``context`` with the CoA chain plus optional CoVe.

    The context is split into chunks; one worker per chunk updates a running
    summary, optionally followed by a verification pass, and the manager
    finally turns the last summary into an answer.

    Args:
        query: The question.
        context: Full context as a single string.
        verbose: Print intermediate prompts and outputs.
        verification_mode: "none" | "every" | "every_k".
        verification_k: Period used with "every_k".
        store_verification_traces: Collect verification traces in the
            pipeline state. They are not part of the return value.
        postprocess: Intended to enable the LLM answer extractor. That step
            is currently force-disabled below, so the flag has no effect.

    Returns:
        The answer string: the text after the manager's "Final answer:"
        marker, or the full manager output when the marker is missing.
    """
    # Split context
    chunks = splitter.split_text(context)
    if verbose:
        print("Text Chunks: ",chunks)

    # Initial CoAState
    state = {
        "query": query,
        "chunks": chunks,
        "i": 0,
        "worker_outputs": [],
        "verbose": verbose,
        "manager_output": "",
        # [CoVe]
        "verification_mode": verification_mode,
        "verification_k": verification_k,
        "store_verification_traces": store_verification_traces,
        "verification_traces": []
    }
    if verbose:
        print("Num Chunks: ", len(chunks))

    # One worker per chunk; each worker reads its chunk via state["i"].
    for i in range(len(chunks)):
        if verbose:
            print(f"Running Worker {i}")
        state = worker_node(state)
        if verbose:
            print(f"Running Worker {i} -- Done")
            print(f"Verifying Worker {i}")
        # Whether a verification actually runs depends on verification_mode.
        state = maybe_run_verification(state)
        if verbose:
            print(f"Verifying Worker {i} -- Done")

    # At the end of the loop, state["i"] should be == len(chunks)
    assert state["i"] == len(chunks), "Total states worked does not equal to number of text chunks"

    # Finally run manager at last
    if verbose:
        print(f"Manager producing output")
    state = manager_node(state)
    raw_answer = state["manager_output"]
    if verbose:
        print("Final Answer before process: ", raw_answer)
    final_ans = _extract_final_answer(raw_answer)
    if verbose and final_ans != raw_answer:
        print("splitting parsing")
    if verbose:
        print(f"Manager producing output -- Done")

    # Post processing
    # NOTE(review): the extractor is deliberately short-circuited with
    # `and False`, so `postprocess` currently does nothing. Remove the
    # `and False` to re-enable it (this adds one LLM call per sample).
    if postprocess and False:
        if verbose:
            print(f"Extractor")
        prompt = EXTRACT_ANSWER_PROMPT(query, state["manager_output"])
        resp = extractor.invoke(prompt)
        final_ans = str(getattr(resp, "content", resp)).strip()
        if verbose:
            print(f"Extractor -- Done")

    if verbose:
        # Single quotes inside the replacement field keep this valid before Python 3.12.
        print(f"Query: {state['query']}\nFinal Answer: {final_ans}")
    return final_ans

## 7) Data & Evaluation

In [12]:
# -----------------------------------------------------------
# 1. HotpotQA Loader
# -----------------------------------------------------------

def load_hotpotqa(split="validation", max_samples=None):
    """
    Load HotpotQA (fullwiki) examples in the pipeline format.

    [source] https://huggingface.co/datasets/hotpotqa/hotpot_qa

    an example in hotpotqa - fullwiki:
    {
        "id": str,
        "question": str,
        "answer": str,
        "type": str,
        "level": str,
        "supporting_facts":
        {
            "title": [str, str, ...], # may repeat
            "sent_id": [int32, int32, ...]
        },
        "context":
        {
            "title": [str, str, ...],
            "sentences": [[str, str, str, ...], [str, str, str, ...], ...]
        }

    }

    Param:
    split: dataset split to read
    max_samples: if set (non-zero), a random subset of this many examples is
        drawn with a RNG seeded by SEED; otherwise the whole split is returned
        in dataset order

    Return:
    a list of dicts
    {
        "id": HotpotQA string id,
        "idx": int index within this split (for convenience),
        "question": str,
        "answer": str,
        "type": "bridge" or "comparison",
        "level": "easy"/"medium"/"hard",
        "context":
        [
            { "title": str, "sentences": [str, str, ...] }, # doc 0
            { "title": str, "sentences": [str, str, ...] }, # doc 1
            ...
        ]
    }
    """
    raw = load_dataset("hotpot_qa", "fullwiki")[split]

    if max_samples:
        # Draw the random subset and actually iterate over it. Previously the
        # shuffled indices were only printed and the first rows were returned.
        rng = random.Random(SEED)
        indices = list(range(len(raw)))
        rng.shuffle(indices)
        indices = indices[:max_samples]
        print("Random sampled indicies: ", indices)
    else:
        indices = range(len(raw))

    data = []
    for idx in indices:
        item = raw[idx]
        context = [
            {
                "title": t,
                "sentences": sents
            }
            for t, sents in zip(item["context"]["title"], item["context"]["sentences"])
        ]

        data.append({
            "id": item["id"],                 # original HotpotQA id (string)
            "idx": idx,                       # integer position in this split
            "question": item["question"],
            "answer": item["answer"],
            "type": item.get("type"),         # "bridge" or "comparison"
            "level": item.get("level"),       # "easy"/"medium"/"hard"
            "context": context,
        })

    return data


def load_hotpotqa_balanced(split="train", per_level=50, levels=("easy", "medium"), seed=SEED):
    """
    Load a balanced subset of HotpotQA examples:
    e.g., 50 easy + 50 medium = 100 total.

    Sampling uses a private ``random.Random(seed)`` so the global RNG state
    is left untouched. The sampled examples and their order are the same as
    those obtained by seeding the global RNG with the same seed.

    Args:
        split: dataset split ("train" or "validation")
        per_level: number of samples per difficulty level
        levels: tuple of levels to include ("easy", "medium", etc.)
        seed: random seed for reproducibility

    Returns:
        A list of dicts in the pipeline format. "idx" is the position of the
        example inside the returned (sampled and shuffled) list.

    Raises:
        ValueError: if a level has fewer than ``per_level`` examples.
    """
    rng = random.Random(seed)

    raw = load_dataset("hotpot_qa", "fullwiki")[split]

    # Bucket row indices by difficulty level. Only the "level" column is read
    # here, so the (large) contexts are decoded for the selected rows only.
    buckets = {level: [] for level in levels}
    for row_idx, lvl in enumerate(raw["level"]):
        if lvl in buckets:
            buckets[lvl].append(row_idx)

    # Sample equally from each level
    selected = []
    for lvl in levels:
        if len(buckets[lvl]) < per_level:
            raise ValueError(f"Not enough samples for level '{lvl}'.")

        selected.extend(rng.sample(buckets[lvl], per_level))

    # Shuffle so the levels are interleaved in the final list
    rng.shuffle(selected)

    # Convert to the pipeline structure
    data = []
    for idx, row_idx in enumerate(selected):
        item = raw[row_idx]
        context = [
            {
                "title": t,
                "sentences": sents
            }
            for t, sents in zip(item["context"]["title"], item["context"]["sentences"])
        ]

        data.append({
            "id": item["id"],                 # original HotpotQA id (string)
            "idx": idx,                       # position in the sampled list
            "question": item["question"],
            "answer": item["answer"],
            "type": item.get("type"),         # "bridge" or "comparison"
            "level": item.get("level"),       # "easy"/"medium"/"hard"
            "context": context,
        })

    return data


# -----------------------------------------------------------
# 2. Context Merger
# -----------------------------------------------------------

def merge_context_fullwiki(context):
    """
    Join the sentences of every document into one string per document.

    Param:
    context: list of documents, each a dict with a "sentences" list
    [
        { "title": str, "sentences": [str, str, ...] }, # doc 0
        { "title": str, "sentences": [str, str, ...] }, # doc 1
        ...
    ]

    Return:
    list[str] with one space-joined text per document, in input order
    (titles are not included)
    """
    return [" ".join(document["sentences"]) for document in context]


# -----------------------------------------------------------
# 3. Evaluation Metrics (EM + F1)
# -----------------------------------------------------------

def normalize_answer(s):
    """
    Normalize an answer string for metric computation.

    Applied in this order: lowercase, drop ASCII punctuation, replace the
    articles a/an/the by a space, collapse runs of whitespace.
    """
    lowered = s.lower()
    punctuation = set(string.punctuation)
    without_punct = ''.join(ch for ch in lowered if ch not in punctuation)
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punct)
    return ' '.join(without_articles.split())


def f1_score(pred, gold):
    """Token-level F1 between a prediction and the gold answer.

    Both strings are normalized with normalize_answer() and split on
    whitespace. If either side has no tokens the score is 1 when both are
    empty and 0 otherwise.
    """
    predicted = normalize_answer(pred).split()
    reference = normalize_answer(gold).split()

    if not predicted or not reference:
        return int(predicted == reference)

    # Multiset overlap: every shared token counts as often as it appears on
    # both sides.
    overlap = 0
    for token in set(predicted).intersection(reference):
        overlap += min(predicted.count(token), reference.count(token))

    if overlap == 0:
        return 0

    precision = overlap / len(predicted)
    recall = overlap / len(reference)
    return (2 * precision * recall) / (precision + recall)


def exact_match(pred, gold):
    """Return True when prediction and gold answer are equal after normalize_answer()."""
    normalized_pred = normalize_answer(pred)
    normalized_gold = normalize_answer(gold)
    return normalized_pred == normalized_gold


# -----------------------------------------------------------
# 4. Evaluation loop
# -----------------------------------------------------------

def evaluate(model_fn: Callable, dataset: List[Dict], log_path: str | None = None, flush_every: int = 10, include_context_in_log: bool = False):
    """
    Run ``model_fn`` on every sample and report mean F1 / exact match.

    model_fn(question, texts) -> str, where ``texts`` is the list of
    per-document strings produced by merge_context_fullwiki().

    If log_path is provided, per-sample metrics will be written to a JSONL file, one JSON object per line.
    This allows:
      - parallel eval on disjoint subsets
      - later merging / error analysis

    The log is opened in append mode: running again with the same log_path
    adds new records after the existing ones (nothing is skipped or
    de-duplicated).

    Args:
        model_fn: callable producing the predicted answer string
        dataset: samples with "question", "context", "answer" (and optionally
            "id", "idx", "type", "level")
        log_path: JSONL path, or None to disable logging
        flush_every: flush + fsync the log every this many samples; values
            <= 0 disable the periodic flush
        include_context_in_log: also store the merged context texts per record

    Returns:
        {"num_samples": int, "f1": mean F1, "em": mean exact match}
    """
    f1s = []
    ems = []

    # Open log file once (append mode)
    log_file = None
    if log_path is not None:
        # A bare file name has no directory part, and makedirs("") would raise.
        log_dir = os.path.dirname(log_path)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        log_file = open(log_path, "a", encoding="utf-8")
        print("logging outputs in: ", log_path)
    try:
        for sample in tqdm(dataset, desc="Evaluating"):
            qid = sample.get("id")
            question = sample["question"]
            context = sample["context"]
            gold = sample["answer"]

            texts = merge_context_fullwiki(context)
            pred = model_fn(question, texts)

            f1 = f1_score(pred, gold)
            em = int(exact_match(pred, gold))

            f1s.append(f1)
            ems.append(em)

            # Per-sample record for logging
            record = {
                "id": qid,
                "idx": sample.get("idx"),
                "type": sample.get("type"),
                "level": sample.get("level"),
                "question": question,
                "gold_answer": gold,
                "prediction": pred,
                "f1": f1,
                "em": em,
            }
            if include_context_in_log:
                record["context"] = texts # can be big; toggle with flag

            if log_file is not None:
                log_file.write(json.dumps(record, ensure_ascii=False) + "\n")

                # Periodic flush so progress is safely on disk
                if flush_every > 0 and len(f1s) % flush_every == 0:
                    log_file.flush()
                    os.fsync(log_file.fileno())

    finally:
        if log_file is not None:
            # Also reached when a sample raises, so the records written so
            # far are pushed to disk before the file is closed.
            try:
                log_file.flush()
                os.fsync(log_file.fileno())
            finally:
                log_file.close()

    num_samples = len(f1s)
    return {
        "num_samples": num_samples,
        "f1": sum(f1s) / num_samples if num_samples > 0 else 0.0,
        "em": sum(ems) / num_samples if num_samples > 0 else 0.0,
    }


# -----------------------------------------------------------
# 5. Placeholder CoA model
# -----------------------------------------------------------
def raw_model(question, context):
    """Full-context baseline: answer the question with one LLM call over the whole context.

    No chunking, workers or verification are involved; the complete
    ``context`` is interpolated into a single prompt for ``llm_strong``.

    Args:
        question: The question to answer.
        context: Source text, inserted into the prompt via f-string formatting.

    Returns:
        The stripped model response. Unlike run_cava(), the text after
        "Final answer:" is NOT extracted here.
    """
    # prompt = f"Context: {context}\nQuestion: {question}\nAnswer: "
    prompt = f"""
    You are solving a long-context question answering task.

    Use ONLY the source text to find information that is relevant to the query

    Output format (very important):
    1. First, write a short reasoning paragraph if needed.
    2. On the LAST line of your response, write exactly:

    Final answer: <answer>

    Rules for <answer>:
    - Use the shortest possible span (a name, location, date, number, or “yes”/“no”).
    - For yes/no questions, answer exactly “yes” or “no”.
    - Do NOT add any text after <answer> on that line.
    - Do NOT write anything after the “Final answer: ...” line (no notes, no extra sentences).

    Query:
    {question}

    Source text:
    {context}
    """
    resp = llm_strong.invoke(prompt)
    # The backend may return a plain string or a message-like object with `.content`.
    final_ans = str(getattr(resp, "content", resp)).strip()
    return final_ans


def cava_wrapper(question: str, context_texts: List[str]) -> str:
    """Adapter with the ``model_fn(question, texts)`` shape expected by evaluate().

    Joins the per-document texts with single spaces and runs the CAVA
    pipeline using the verification settings from the config cell.
    """
    # == Run CAVA or CoA ==
    # (Full Context Baseline alternative: raw_model(question, " ".join(context_texts)))
    return run_cava(
        query=question,
        context=" ".join(context_texts),
        verbose=False,
        verification_mode=VERIFICATION_MODE,
        verification_k=VERIFICATION_K,
        store_verification_traces=True,
    )


# -----------------------------------------------------------
# 6. Running the pipeline
# -----------------------------------------------------------

if __name__ == "__main__":
    # Full Hotpot QA fullwiki "train": 90447, "validation": 7405
    # Choose one of the following dataset loaders:

    # Option 1: Load a random subset from the validation split (e.g., harder questions)
    # data = load_hotpotqa(split="validation", max_samples=NUM_SAMPLES_TO_LOAD)

    # Option 2: Load a balanced subset from the training split by difficulty level
    # (returns NUM_SAMPLES_TO_LOAD_PER_LEVEL * len(LEVELS) samples)
    data = load_hotpotqa_balanced(split="train", per_level=NUM_SAMPLES_TO_LOAD_PER_LEVEL, levels=LEVELS, seed=SEED)

    # Evaluate only the configured slice; per-sample records are appended to LOG_PATH.
    subset = data[SUBSET_START:SUBSET_END]
    results = evaluate(cava_wrapper, subset, log_path=LOG_PATH, flush_every=FLUSH_EVERY, include_context_in_log=INCLUDE_CONTEXT_IN_LOG)



README.md: 0.00B [00:00, ?B/s]

fullwiki/train-00000-of-00002.parquet:   0%|          | 0.00/166M [00:00<?, ?B/s]

fullwiki/train-00001-of-00002.parquet:   0%|          | 0.00/166M [00:00<?, ?B/s]

fullwiki/validation-00000-of-00001.parqu(…):   0%|          | 0.00/28.0M [00:00<?, ?B/s]

fullwiki/test-00000-of-00001.parquet:   0%|          | 0.00/27.6M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/90447 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/7405 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7405 [00:00<?, ? examples/s]

logging outputs in:  /content/drive/MyDrive/GenAI/project/logs/hotpotqa_qwen3B_every3_test.jsonl


Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Evaluating: 100%|██████████| 10/10 [05:26<00:00, 32.67s/it]


In [13]:
# Aggregate metrics returned by evaluate() for the run above (this run only).
results

{'num_samples': 10, 'f1': 0.24695064232589586, 'em': 0.2}

## 8) Log Inspection

In [14]:
print(LOG_PATH)
data = []
# The log is written as UTF-8 with ensure_ascii=False, so read it back with
# the same encoding instead of the platform default.
with open(LOG_PATH, "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines
        record = json.loads(line)  # parse each line once
        data.append(record)
        print(record)

print(len(data))

/content/drive/MyDrive/GenAI/project/logs/hotpotqa_qwen3B_every3_test.jsonl
{'id': '5ac0ce595542997d64295a54', 'idx': 0, 'type': 'bridge', 'level': 'medium', 'question': 'What animal symbol of Gran Cararia was involved in the death of Diane Whipple in 2001?', 'gold_answer': 'Perro de Presa Canario', 'prediction': 'canarian houbara', 'f1': 0, 'em': 0}
{'id': '5ac159125542994ab5c67ced', 'idx': 1, 'type': 'bridge', 'level': 'medium', 'question': 'In the 2001 census what was the population of township in which Fernyhalgh Wood is located ?', 'gold_answer': '33,171', 'prediction': '2,879', 'f1': 0, 'em': 0}
{'id': '5a8f62c755429918e830d20d', 'idx': 2, 'type': 'bridge', 'level': 'easy', 'question': 'Craig "Chief" Berube (born December 17, 1965) is a Canadian former professional ice hockey player and the former head coach of the Chicago Wolves, a professional ice hockey team playing in the Central Division of the Western Conference, of which organization?', 'gold_answer': 'American Hockey Leag

In [15]:
def load_metrics_from_jsonl(path):
    """Recompute mean F1 / EM from a per-sample JSONL log.

    Every non-blank line is parsed as one JSON record; a record without an
    "f1" or "em" key counts as 0 for that metric. All records in the file
    are included, so a log that several runs appended to yields metrics over
    all of those runs.

    Args:
        path: Path of the JSONL file (read as UTF-8, matching how the log is
            written).

    Returns:
        {"num_samples": int, "mean_f1": float, "mean_em": float,
         "samples": list of the parsed records}; the means are 0 for an
        empty file.
    """
    samples = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            samples.append(json.loads(line))

    f1s = [entry.get("f1", 0) for entry in samples]
    ems = [entry.get("em", 0) for entry in samples]

    mean_f1 = sum(f1s) / len(f1s) if f1s else 0
    mean_em = sum(ems) / len(ems) if ems else 0

    return {
        "num_samples": len(samples),
        "mean_f1": mean_f1,
        "mean_em": mean_em,
        "samples": samples,
    }

# Aggregates over EVERY record in the log file. evaluate() appends to
# LOG_PATH, so after repeated runs the sample count here is larger than the
# size of a single run.
res = load_metrics_from_jsonl(LOG_PATH)
print("Samples:", res["num_samples"])
print("Mean F1:", res["mean_f1"])
print("Mean EM:", res["mean_em"])


Samples: 30
Mean F1: 0.2469506423258959
Mean EM: 0.2
