In [None]:
import os, json, random, re
import numpy as np
import pandas as pd
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification


In [None]:
# Download the dataset
!mkdir -p dataset
!gsutil cp gs://gresearch/ASQA/data/ASQA.json dataset/
import json
import pandas as pd
from collections import defaultdict

# Load the raw ASQA JSON from the relative path dataset/ASQA.json
with open('dataset/ASQA.json', 'r') as f:
    data = json.load(f)

# One flat dict per ASQA item, collected across every split
all_records = []

# The top level of the file maps split name -> {item_id: item}
for split, items_dict in data.items():
    for item_id, item in items_dict.items():
        # Basic fields; 'source_dataset' is '' when 'metadata' is missing or is not a dict
        record = {
            'id': item_id,
            'split': split,
            'ambiguous_question': item.get('ambiguous_question', ''),
            'source_dataset': item.get('metadata', {}).get('dataset', '') if isinstance(item.get('metadata'), dict) else ''
        }

        # Process QA pairs
        qa_pairs = item.get('qa_pairs', [])
        if qa_pairs:
            # Only the FIRST QA pair is kept; any further pairs are dropped.
            # Items without QA pairs do not set these keys at all.
            main_qa = qa_pairs[0]
            record.update({
                'question': main_qa.get('question', ''),
                'short_answers': ', '.join(main_qa.get('short_answers', [])),
                'context': main_qa.get('context', ''),
                'wikipage': main_qa.get('wikipage', '')
            })

        # Process annotations
        annotations = item.get('annotations', [])
        if annotations:
            # Only the FIRST annotation is kept; its long answer becomes 'reference_answer'
            main_annotation = annotations[0]
            record.update({
                'reference_answer': main_annotation.get('long_answer', ''),
                'additional_answers': ' | '.join(main_annotation.get('additional_answers', [])),
                'annotator_id': main_annotation.get('annotator_id', '')
            })

        all_records.append(record)

# Create DataFrame
asqa_df = pd.DataFrame(all_records)

# Add source dataset if missing
# NOTE(review): every record above sets 'source_dataset', so this branch looks unreachable
# whenever at least one record exists — confirm whether the fallback is still wanted.
if 'source_dataset' not in asqa_df.columns:
    asqa_df['source_dataset'] = asqa_df['id'].apply(lambda x: 'nq' if 'nq' in x else 'wq')

# Filter for dev split (the now-constant 'split' column is dropped)
dev_df = asqa_df[asqa_df['split'] == 'dev'].drop('split', axis=1)

print(f"Created DataFrame with {len(dev_df)} dev records")
print("\nDataFrame columns:", dev_df.columns.tolist())
print("\nFirst row:")
# Raises IndexError when dev_df has no rows
print(dev_df.iloc[0])

import os

# Clone the ALCE repository once; the clone is skipped when the checkout already exists.
if not os.path.exists("/content/ALCE"):
    print("Cloning ALCE repository...")
    clone_status = os.system("git clone https://github.com/princeton-nlp/ALCE.git /content/ALCE")
    # os.system reports failure only through its return value, so check it explicitly
    # instead of carrying on with a missing or partial checkout.
    if clone_status != 0:
        raise RuntimeError(f"git clone of ALCE failed with status {clone_status}")

# The download script is launched with the repository as the working directory.
# The working directory stays changed for the rest of the notebook.
os.chdir("/content/ALCE")
download_status = os.system("bash download_data.sh")
if download_status != 0:
    raise RuntimeError(f"download_data.sh failed with status {download_status}")

# Keep the exit status as the cell's displayed value, as before.
download_status

Copying gs://gresearch/ASQA/data/ASQA.json...
\ [1 files][ 13.9 MiB/ 13.9 MiB]                                                
Operation completed over 1 objects/13.9 MiB.                                     
Created DataFrame with 948 dev records

DataFrame columns: ['id', 'ambiguous_question', 'source_dataset', 'question', 'short_answers', 'context', 'wikipage', 'reference_answer', 'additional_answers', 'annotator_id']

First row:
id                                                 -7013890438520559398
ambiguous_question         Who has the highest goals in world football?
source_dataset                                                         
question              Who has the highest goals in men's world inter...
short_answers                                            Daei, Ali Daei
context                                             No context provided
wikipage                                                           None
reference_answer      Ali Dael has the highest goals in men

0

In [None]:
import os
import json
import pandas as pd
from tqdm import tqdm

# 1. Directory that holds the JSON retrieval files
RETRIEVAL_DIR = "/content/ALCE/data"   # adjust this if your path is different

# 2. Keep the files whose name ends with ".json" and contains "eval"
all_files = os.listdir(RETRIEVAL_DIR)
json_files = [f for f in all_files if f.endswith(".json") and "eval" in f]

# Show which files were selected
print("Found JSON retrieval files:")
for fn in json_files:
    print("  ", fn)

# 3. Helper function: load a JSON file (either a JSON array or JSONL) into a Python list
def load_json_as_list(fullpath):
    """
    Load a JSON array file or a JSON-lines file into a Python list.

    Args:
        fullpath: Path of the file to read (decoded as UTF-8).

    Returns:
        The parsed document when the stripped text starts with '[' and ends
        with ']'; otherwise a list holding one parsed object per non-blank
        line. Lines that are not valid JSON are skipped, and a warning
        reports how many were dropped so truncated files do not go unnoticed.

    Raises:
        json.JSONDecodeError: if an array-shaped file is not valid JSON.
    """
    import warnings  # local import keeps the cell self-contained

    with open(fullpath, "r", encoding="utf8") as f:
        text = f.read().strip()

    # If it begins with '[' and ends with ']', assume it's a JSON array
    if text.startswith("[") and text.endswith("]"):
        return json.loads(text)

    # Otherwise parse line by line (JSONL)
    objs = []
    skipped = 0
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        try:
            objs.append(json.loads(line))
        except json.JSONDecodeError:
            # Still best-effort, but no longer silent
            skipped += 1
    if skipped:
        warnings.warn(
            f"{fullpath}: skipped {skipped} malformed JSON line(s)",
            stacklevel=2,
        )
    return objs

# 4. Load each retrieval file into its own DataFrame, keyed by file name without ".json"
dfs = {}
for fn in json_files:
    full_path = os.path.join(RETRIEVAL_DIR, fn)
    data_list = load_json_as_list(full_path)

    df = pd.DataFrame(data_list)


    # Add a JSON-string copy of the passage list; non-list cells become "".
    # If a file has both columns, 'retrieved_passages' overwrites the value taken from 'docs'.
    if "docs" in df.columns:
        df["docs_str"] = df["docs"].apply(lambda L: json.dumps(L) if isinstance(L, list) else "")
    if "retrieved_passages" in df.columns:
        df["docs_str"] = df["retrieved_passages"].apply(lambda L: json.dumps(L) if isinstance(L, list) else "")

    key = fn.replace(".json", "")
    dfs[key] = df
# Raises KeyError if the oracle file was not among json_files
asqa_oracle = dfs["asqa_eval_gtr_top100_reranked_oracle"]
# Rename 'sample_id' to 'id' so the merge key matches dev_df
asqa_oracle = asqa_oracle.rename(columns={"sample_id": "id"})

# Keep only the columns we need from asqa_oracle:
oracle_subset = asqa_oracle[["id", "wikipages", "docs"]].copy()

# Inner merge: only ids present in both dev_df and the oracle file survive
df_all = pd.merge(dev_df, oracle_subset, how="inner", on="id")
print(f"df_all shape: {df_all.shape}")   # (948, 12) in the recorded run

# Now df_all has columns:
#   [ 'id', 'ambiguous_question', 'source_dataset', 'question', 'short_answers',
#     'context', 'wikipage', 'reference_answer', 'additional_answers',
#     'annotator_id', 'wikipages', 'docs' ]
#
#   • df_all['docs'] is a Python list of dicts: [ { 'id': '...', 'title': '…' }, … ]
#   • df_all['wikipages'] is the gold set: [ { 'title':'…', 'url':'…' }, … ]
# ─────────────────────────────────────────────────────────────────────────────

Found JSON retrieval files:
   asqa_eval_dpr_top100.json
   qampari_eval_dpr_top100.json
   qampari_eval_gtr_top100.json
   eli5_eval_bm25_top100_reranked_oracle.json
   asqa_eval_gtr_top100.json
   asqa_eval_gtr_top100_reranked_oracle.json
   eli5_eval_bm25_top100.json
   qampari_eval_gtr_top100_reranked_oracle.json
df_all shape: (948, 12)


In [None]:
from glob import glob

# Absolute path of the oracle-reranked ASQA retrieval file inside the ALCE checkout
oracle_path = "/content/ALCE/data/asqa_eval_gtr_top100_reranked_oracle.json"

# Load as either a JSON array or JSONL.
# Unlike a best-effort loader, a malformed JSONL line raises json.JSONDecodeError here.
with open(oracle_path, "r", encoding="utf8") as f:
    text = f.read().strip()
    if text.startswith("[") and text.endswith("]"):
        oracle_list = json.loads(text)
    else:
        oracle_list = [json.loads(line) for line in text.splitlines() if line.strip()]

asqa_oracle = pd.DataFrame(oracle_list)

# Rename `sample_id` → `id` so it matches dev_df
asqa_oracle = asqa_oracle.rename(columns={"sample_id": "id"})

# Keep only [id, wikipages, docs] for merging
oracle_subset = asqa_oracle[["id", "wikipages", "docs"]].copy()

# Inner merge with dev_df; this (re)assigns the module-level df_all used by the evaluation code
df_all = pd.merge(dev_df, oracle_subset, how="inner", on="id")
print(f"After merging ASQA‐dev with oracle retrieval, df_all.shape = {df_all.shape}")


After merging ASQA‐dev with oracle retrieval, df_all.shape = (948, 12)


In [None]:
N_DOCS       = 5     # number of passages shown / citation indices
MAX_TOKENS   = 700   # max_tokens that generate_with_enforcement passes to llama_generate
TEMPERATURE  = 0.0   # temperature of the first generation pass in generate_with_enforcement

In [None]:
pip install groq

Collecting groq
  Downloading groq-0.31.0-py3-none-any.whl.metadata (16 kB)
Downloading groq-0.31.0-py3-none-any.whl (131 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/131.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m131.4/131.4 kB[0m [31m13.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: groq
Successfully installed groq-0.31.0


In [None]:
from groq import Groq

# Initialize the Groq client with your API key (set your actual key here or via an environment variable)
# Read the key from the GROQ_API_KEY environment variable so that no credential is
# stored in the notebook. If the variable is unset, None is passed through.
# NOTE(review): the SDK is expected to reject a missing key at construction — confirm.
client = Groq(
    api_key=os.environ.get("GROQ_API_KEY")
)

In [None]:
def generate_with_enforcement(builder, row, n_docs=N_DOCS):
    """
    Generate an answer and retry once when it carries no numeric citation.

    Args:
        builder: Callable taking `row` and returning the prompt text.
        row: Record handed to `builder` and to `build_rag_block`.
        n_docs: Number of passages quoted in the rewrite prompt.

    Returns:
        The first generation if its answer text contains at least one "[k]"
        marker; otherwise the output of a single rewrite request. The rewrite
        is returned as-is and is not re-checked.
    """
    first_pass = llama_generate(builder(row), max_tokens=MAX_TOKENS, temperature=TEMPERATURE)

    # Only the presence of one [number] is checked, not one per sentence.
    has_citation = re.search(r'\[\s*\d+\s*\]', extract_answer_text(first_pass)) is not None
    if has_citation:
        return first_pass

    # Ask the model to rewrite its own output with citations.
    rewrite_prompt = f"""Rewrite the FINAL ANSWER below so that EVERY sentence ends with a numeric citation [k] (k∈1..{n_docs}).
Use only square brackets. Do not change the content.

Passages:
{build_rag_block(row, n_docs)}

Original:
{first_pass}
"""
    return llama_generate(rewrite_prompt, max_tokens=MAX_TOKENS, temperature=0.0)


In [None]:
def llama_generate(
    prompt: str,
    max_tokens: int = 512,
    temperature: float = 0.0,
    top_p: float = 1.0
) -> str:
    """
    Send a single-turn user prompt to the Groq chat-completions endpoint.

    The request always uses the hard-coded model id "llama3-8b-8192" and the
    module-level `client`.

    Args:
        prompt: Text sent as the only user message.
        max_tokens: Forwarded as `max_tokens`.
        temperature: Forwarded as `temperature`.
        top_p: Forwarded as `top_p`.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or "" when the response carries no content.
    """
    # 1. Call the chat/completion endpoint
    chat_completion = client.chat.completions.create(
        model="llama3-8b-8192",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )

    # 2. Extract the generated text via attribute access
    #    (NOT via chat_completion["choices"][0]["message"]["content"])
    generated_text = chat_completion.choices[0].message.content
    # Guard against a missing content field; calling .strip() on None would raise.
    return (generated_text or "").strip()

In [None]:
# Extract only the 'Answer:' portion before 'Citations:'.
# The pattern is a raw string, so escapes use a SINGLE backslash: r'\\s' would match a
# literal backslash followed by "s" and the pattern would never match real model output.
ANSWER_PAT = re.compile(r'(?is)answer\s*:\s*(.*?)(?:\n\s*citations\s*:\s*|$)')

def extract_answer_text(output: str) -> str:
    """
    Return the answer portion of a model output.

    When an "answer:" label is found (case-insensitive), the text after it up
    to a following "citations:" line or the end of the string is returned.
    Otherwise everything before the first literal "Citations:" is returned.
    The result is stripped of surrounding whitespace.
    """
    m = ANSWER_PAT.search(output)
    return m.group(1).strip() if m else output.split("Citations:")[0].strip()

# Split after ., ! or ? when the whitespace is followed by an uppercase letter or a digit
SENT_SPLIT = re.compile(r'(?<=[.!?])\s+(?=[A-Z0-9])')
def segment_statements(output: str) -> list:
    """Split the extracted answer text into non-empty, stripped sentences."""
    text = extract_answer_text(output)
    parts = SENT_SPLIT.split(text)
    return [p.strip() for p in parts if p.strip()]

# Citation marker: one opening bracket character, a number, one closing bracket character.
# The opening and closing sets are matched independently of each other.
CITE_RE = re.compile(
    r"[［\[\(【〔]\s*(\d+)\s*[］\]\)】〕]"  # any variant of brackets with a number inside
)

def align_citations(output: str) -> dict:
    """
    Map each statement index to the citation numbers found in that statement.

    Numbers are de-duplicated (so "[1][1]" yields [1]) and returned in
    ascending order.
    """
    return {
        position: sorted(set(map(int, CITE_RE.findall(statement))))
        for position, statement in enumerate(segment_statements(output))
    }





In [None]:
def extract_answer_text(output: str) -> str:
    """
    Return the answer portion of a model output.

    Uses the first ANSWER_PAT match when there is one; otherwise returns the
    text before the first literal "Citations:". The result is stripped.
    """
    m = ANSWER_PAT.search(output)
    return m.group(1).strip() if m else output.split("Citations:")[0].strip()

# Split after ., ! or ? when the whitespace is followed by an uppercase letter or a digit.
# The pattern is a raw string, so escapes use a SINGLE backslash: r'\\s+' would require a
# literal backslash in the text and the splitter would never split anything.
SENT_SPLIT = re.compile(r'(?<=[.!?])\s+(?=[A-Z0-9])')
def segment_statements(output: str) -> list:
    """Split the extracted answer text into non-empty, stripped sentences."""
    text = extract_answer_text(output)
    parts = SENT_SPLIT.split(text)
    return [p.strip() for p in parts if p.strip()]

# Citation marker with a single capture group holding the number.
# Accepted opening characters: ［ [ ( 【 〔   accepted closing characters: ］ ] ) 】 〕
# Opening and closing characters are matched independently, so mixed pairs such as "[1)"
# also match.
CITE_RE = re.compile(
    r"[［\[\(【〔]\s*(\d+)\s*[］\]\)】〕]"  # any variant of brackets with a number inside
)

In [None]:
import nltk
# Download the 'punkt_tab' resource into the local nltk data directory.
# NOTE(review): presumably needed by nltk's sentence tokenizer used during evaluation — confirm.
nltk.download('punkt_tab')

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [None]:

import os, re, json, math, random, time
import numpy as np
import pandas as pd
from tqdm import tqdm

# -----------------------------
# Config
# -----------------------------
RANDOM_SEED   = 2025   # seeds numpy/random below; also the default seed for sampling and bootstrap
SAMPLE_FRAC   = 1      # fraction of df_all to evaluate; values below 1.0 trigger random sampling
N_DOCS        = 5      # passages shown per question (re-declares the earlier config value)
BOOT_B        = 1000   # default number of bootstrap resamples in bootstrap_ci
ALPHA         = 0.05   # default alpha in bootstrap_ci; interval bounds are the alpha/2 and 1-alpha/2 percentiles
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)


# Shared formatting instructions included in every prompt.
# "{n}" is a str.format placeholder filled with the number of passages shown.
FORMAT_RULES = (
    "Answer using only the passages below. Be concise and precise. "
    "In the FINAL ANSWER, append a citation like [k] after every sentence, "
    "where k is the number of the supporting passage shown (1..{n}). "
    "If more than one passage supports a sentence, use separate brackets, e.g., [1][3]. "
    "Use square brackets only. Do not restate the question. "
    "Avoid unsupported generalizations (e.g., 'widely recognized', 'unlikely')."
)

def _prefix(strategy_text: str, n_docs: int) -> str:
    """
    Build the prompt template for one strategy.

    The result still contains the "{q}" and "{rag}" placeholders; callers
    fill them with str.format. FORMAT_RULES is already formatted with
    n=n_docs when this returns.
    """
    strategy_line = f"{strategy_text}\n"
    rules = FORMAT_RULES.format(n=n_docs)
    question_and_passages = "\n\nQuestion:\n{q}\n\nRetrieved Passages (numbered):\n{rag}\n"
    answer_cue = "\nFINAL ANSWER:\n"
    return "".join([strategy_line, rules, question_and_passages, answer_cue])

# -----------------------------
# Build the RAG block (numbered passages)
# -----------------------------
def build_rag_block(row: pd.Series, n_docs: int = N_DOCS) -> str:
    """
    Render the first `n_docs` passages of row["docs"] as a numbered text block.

    Each entry shows the passage title ("N/A" when absent) and the first 100
    whitespace-separated words of its text, with newlines flattened to
    spaces. A missing or None text renders as an empty snippet.
    """
    def _entry(rank, passage):
        heading = passage.get("title", "N/A")
        flattened = (passage.get("text", "") or "").replace("\n", " ").strip()
        preview = " ".join(flattened.split()[:100])
        return f"{rank}. Title: {heading}\n   {preview}\n"

    header = f"Retrieved Passages (top {n_docs}/100):"
    entries = [
        _entry(rank, passage)
        for rank, passage in enumerate(row["docs"][:n_docs], start=1)
    ]
    return "\n".join([header, *entries])

# -----------------------------
# Prompt builders (few-shot keeps exemplars)
# -----------------------------
def _strategy_prompt(row, strategy: str, n_docs: int) -> str:
    """Fill the shared prompt template for `strategy` with the row's question and passages."""
    q, rag = row["question"], build_rag_block(row, n_docs)
    return _prefix(strategy, n_docs).format(q=q, rag=rag)

def zero_shot_explicit_rag_prompt(row, n_docs: int = N_DOCS) -> str:
    """Prompt with no extra guidance beyond the shared format rules."""
    strategy = "Zero-shot answering with inline evidence."
    return _strategy_prompt(row, strategy, n_docs)

def cot_standard_explicit_rag_prompt(row, n_docs: int = N_DOCS) -> str:
    """Prompt asking for brief hidden reasoning followed by the final answer only."""
    strategy = ("Reason briefly using the passages, then output only the FINAL ANSWER. "
                "Do not show your reasoning.")
    return _strategy_prompt(row, strategy, n_docs)

def chain_of_citation_prompt(row, n_docs: int = N_DOCS) -> str:
    """Prompt asking for step-by-step reasoning with a citation on each factual step."""
    strategy = ("Use the passages to reason step by step, appending [k] to each factual step. "
                "THEN provide the FINAL ANSWER with [k] after every sentence.")
    return _strategy_prompt(row, strategy, n_docs)

def chain_of_quote_prompt(row, n_docs: int = N_DOCS) -> str:
    """Prompt asking for step-by-step reasoning with a direct quote and citation per step."""
    strategy = ("Build your reasoning step by step. At each step, include a short direct quote "
                "from the supporting passage and append [k]. THEN provide the FINAL ANSWER with [k] "
                "after every sentence.")
    return _strategy_prompt(row, strategy, n_docs)

def chain_of_verification_prompt(row, n_docs: int = N_DOCS) -> str:
    """Prompt asking the model to draft, verify each claim, and keep only supported claims."""
    strategy = ("Draft your answer, then verify each claim against the passages. "
                "Keep only claims that are directly supported. Finally output the FINAL ANSWER "
                "with [k] after every sentence.")
    return _strategy_prompt(row, strategy, n_docs)

def self_verification_prompt(row, n_docs: int = N_DOCS) -> str:
    """Prompt asking the model to self-check each cited sentence before answering."""
    strategy = ("First, produce an answer with inline citations. Then self-check each sentence: "
                "if its citation does not support it, revise or remove it. Output only the FINAL ANSWER.")
    return _strategy_prompt(row, strategy, n_docs)

def motivation_prompt(row, n_docs: int = N_DOCS) -> str:
    """Prompt stating that accuracy and attribution will be evaluated."""
    strategy = ("Your output will be evaluated for factual accuracy and proper attribution. "
                "Incorrect, missing, or hallucinated citations will be penalized.")
    return _strategy_prompt(row, strategy, n_docs)

def role_prompt(row, n_docs: int = N_DOCS) -> str:
    """Prompt assigning the model a researcher persona."""
    strategy = ("You are a researcher specializing in NLP/ML. Use the passages to answer concisely.")
    return _strategy_prompt(row, strategy, n_docs)

# Few-shot: exemplar format only (no content leakage)
def few_shot_qa_explicit_rag_prompt(row, example_ids, n_docs: int = N_DOCS) -> str:
    """
    Build a few-shot prompt whose exemplars show the output format only.

    Each exemplar uses the example's real question and its first two passage
    titles, but a placeholder answer ("Fact A. [1] Fact B. [2]").

    Args:
        row: Record providing "question" and "docs" for the target item.
        example_ids: Ids looked up in the module-level `df_all`; an unknown id
            raises IndexError.
        n_docs: Number of passages shown for the target question.
    """
    def _title_at(docs, pos):
        fallback = f"Passage {pos + 1}"
        return docs[pos].get("title", fallback) if len(docs) >= pos + 1 else fallback

    def _demo(ex_id):
        example = df_all[df_all["id"] == ex_id].iloc[0]
        first_title = _title_at(example["docs"], 0)
        second_title = _title_at(example["docs"], 1)
        return (
            f"Question: {example['question']}\n"
            f"Answer: Fact A. [1] Fact B. [2]\n"
            f"Citations:\n1. {first_title}\n2. {second_title}\n"
        )

    demo_block = "\n".join([_demo(ex_id) for ex_id in example_ids])
    q = row["question"]
    rag = build_rag_block(row, n_docs)
    header = "Few-shot answering with inline evidence.\n" + FORMAT_RULES.format(n=n_docs)
    target = f"\nQuestion:\n{q}\n\nRetrieved Passages (numbered):\n{rag}\n\nFINAL ANSWER:\n"
    return header + "\n\nExamples:\n" + demo_block + target

# Strategy registry: strategy name -> prompt builder.
# evaluate_asqa runs every entry listed here, so commenting an entry out disables it.
# The name "few-shot" is special-cased by make_builder_adapter, which passes example ids.
STRATEGY_BUILDERS = {
    #"zero-shot":        zero_shot_explicit_rag_prompt,
    #"few-shot":         few_shot_qa_explicit_rag_prompt,   # needs example_ids
    "chain-of-thought": cot_standard_explicit_rag_prompt,
    "chain-citation":   chain_of_citation_prompt,
    "chain-quote":      chain_of_quote_prompt,
    #"chain-verify":     chain_of_verification_prompt,
    #"self-verify":      self_verification_prompt,
    #"motivation":       motivation_prompt,
    #"role":             role_prompt,
}

# -----------------------------
# Answer normalization & parsing
# -----------------------------
# Line-anchored, case-insensitive section labels.
QUESTION_LINE_RE = re.compile(r"(?im)^\s*question\s*:\s*")
ANSWER_LINE_RE   = re.compile(r"(?im)^\s*answer\s*:\s*")
FINAL_ANSWER_RE  = re.compile(r"(?im)^\s*final\s*answer\s*:\s*")
CITATIONS_HDR_RE = re.compile(r"(?im)^\s*Citations\s*:\s*$")

# Labelled blocks removed from the answer; each runs until the next line that starts
# with a non-space character, or the end of the text.
STRIP_BLOCKS = [
    re.compile(block_pattern)
    for block_pattern in (
        r"(?ims)^\s*draft\s*answer\s*:\s*.*?(?=^\S|\Z)",
        r"(?ims)^\s*reasoning\s*:\s*.*?(?=^\S|\Z)",
        r"(?ims)^\s*here\s+is\s+the\s+step-?by-?step.*?(?=^\S|\Z)",
    )
]

def normalize_final_answer(output: str, row=None) -> str:
    """
    Reduce a raw model output to its final-answer text.

    Steps, in order: keep only what follows the first "FINAL ANSWER:" label
    (or, failing that, the first "Answer:" label); drop lines starting with
    "Question:"; remove the labelled blocks in STRIP_BLOCKS; and, when `row`
    carries a string "question", drop lines equal to that question.

    Falsy input (None, "") is returned unchanged.
    """
    if not output:
        return output

    # A FINAL ANSWER label wins over a plain Answer label.
    label = FINAL_ANSWER_RE.search(output) or ANSWER_LINE_RE.search(output)
    if label:
        output = output[label.end():].strip()

    kept = [ln for ln in output.splitlines() if QUESTION_LINE_RE.match(ln) is None]
    output = "\n".join(kept).strip()

    for block_re in STRIP_BLOCKS:
        output = block_re.sub("", output).strip()

    question = row.get("question") if row is not None else None
    if isinstance(question, str):
        echoed = question.strip()
        remaining = (ln for ln in output.splitlines() if ln.strip() != echoed)
        output = "\n".join(remaining).strip()
    return output

# Segment sentences (prefer NLTK if available)
try:
    from nltk.tokenize import sent_tokenize as _nltk_sent_tokenize
    _HAS_NLTK = True
except Exception:
    _HAS_NLTK = False

# One or more "[k]" markers at the very start of a segment.
_LEADING_CITES_RE = re.compile(r"^((?:\s*\[\s*\d+\s*\])+)\s*")
# Segments starting with these labels are scaffolding, not answer statements.
_SKIP_PREFIXES = ("Draft Answer:", "Reasoning:", "Here is the step", "FINAL ANSWER:")

def _split_sentences(text: str) -> list:
    """Tokenize with NLTK when usable, otherwise split after ., ! or ? followed by whitespace."""
    if _HAS_NLTK:
        try:
            return _nltk_sent_tokenize(text)
        except LookupError:
            # The tokenizer resource is not installed; use the regex splitter instead.
            pass
    return re.split(r'(?<=[\.!?])\s+', text)

def segment_statements(output: str) -> list:
    """
    Split a model output into answer statements.

    Only the text before the first literal "Citations:" is considered.
    Citation markers written AFTER the sentence-ending punctuation
    ("Fact. [1] Next fact. [2]") end up at the start of the following
    segment once the text is split; they are moved back onto the sentence
    they follow so that each statement keeps its own citations. Markers at
    the start of the first segment are left where they are.

    Args:
        output: Model output; falsy values give an empty list.

    Returns:
        Stripped, non-empty statements, excluding those that start with one
        of the scaffolding labels in _SKIP_PREFIXES.
    """
    if not output:
        return []
    text = output.split("Citations:")[0].strip()

    statements = []
    for raw in _split_sentences(text):
        sent = raw.strip()
        leading = _LEADING_CITES_RE.match(sent)
        if leading and statements:
            statements[-1] = statements[-1] + " " + leading.group(1).strip()
            sent = sent[leading.end():].strip()
        if sent:
            statements.append(sent)
    return [s for s in statements if not s.startswith(_SKIP_PREFIXES)]

# Inline citation parser. Accepts ASCII square brackets and parentheses plus three
# full-width/CJK bracket pairs; each alternative has its own capture group.
CITE_RE = re.compile(
    r"\[\s*(\d+)\s*\]"   # [1]
    r"|［\s*(\d+)\s*］"  # full-width
    r"|【\s*(\d+)\s*】"  # black-lenticular
    r"|\(\s*(\d+)\s*\)" # (1)
    r"|〔\s*(\d+)\s*〕"  # Japanese bracket
)

def align_citations(output: str) -> dict:
    """
    Map each statement index to the citation numbers found in that statement.

    Numbers are kept in order of appearance and repeated markers are not
    de-duplicated.
    """
    # Exactly one alternative (hence one group) takes part in each match,
    # so m.lastindex points at the group that captured the number.
    return {
        position: [int(m.group(m.lastindex)) for m in CITE_RE.finditer(sentence)]
        for position, sentence in enumerate(segment_statements(output))
    }

# Ensure we can display a Citations block (optional, not used for scoring)
def ensure_citations_block(row, output: str) -> str:
    """
    Append a "Citations:" listing of the passages cited inline.

    The output is returned unchanged when it is empty, already has a
    "Citations:" header line, or cites no passage number within
    1..len(row["docs"]).

    Each listed entry is labelled with the passage number used inline, so
    "[3]" in the answer corresponds to the line starting with "3." in the
    listing. Passages without a title are shown as "Passage k".
    """
    if not output or CITATIONS_HDR_RE.search(output):
        return output
    cmap = align_citations(output)
    used = sorted({idx for ids in cmap.values() for idx in ids if 1 <= idx <= len(row["docs"])})
    if not used:
        return output
    lines = ["", "Citations:"]
    for k in used:
        title = row["docs"][k-1].get("title", f"Passage {k}")
        # Label with k itself; renumbering 1..n would break the link to the inline markers.
        lines.append(f"{k}. {title}")
    return output.rstrip() + "\n" + "\n".join(lines) + "\n"

# -----------------------------
# Overlap (title-based)
# -----------------------------
def canon_title(t: str) -> str:
    """
    Canonicalize a page title for comparison.

    Lowercases and trims the title, collapses whitespace runs to single
    spaces, and removes one trailing parenthesised qualifier. None or an
    empty title yields "".
    """
    lowered = (t or "").lower().strip()
    single_spaced = re.sub(r"\s+", " ", lowered)
    # drop trailing (…)
    return re.sub(r"\s*\([^)]*\)\s*$", "", single_spaced)

def titles_cited_by_indices(row, output: str):
    """
    Return the titles of the passages cited anywhere in `output`.

    Citation numbers outside 1..len(row["docs"]) are ignored. Titles are
    returned once per distinct passage number, in ascending number order;
    a passage without a title contributes "".
    """
    cited = set()
    for ids in align_citations(output).values():
        cited.update(idx for idx in ids if 1 <= idx <= len(row["docs"]))
    return [row["docs"][number - 1].get("title", "") for number in sorted(cited)]

def overlap_metrics(row, output: str):
    """
    Title-overlap precision and recall between cited and gold pages.

    Cited titles come from the passages referenced in `output`; gold titles
    come from row["wikipages"]. Both sides are canonicalized with
    canon_title and compared as sets.

    Returns:
        (precision, recall); each is 0.0 when its denominator set is empty.
    """
    predicted = {canon_title(title) for title in titles_cited_by_indices(row, output) if title}
    gold_pages = row.get("wikipages") or []
    gold = {canon_title(page.get("title", "")) for page in gold_pages if page.get("title")}

    hits = len(predicted & gold)
    precision = hits / len(predicted) if predicted else 0.0
    recall = hits / len(gold) if gold else 0.0
    return precision, recall

# -----------------------------
# NLI scorer (MNLI model)
# -----------------------------
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification

DEVICE      = "cuda" if torch.cuda.is_available() else "cpu"
MNLI_MODEL  = "roberta-large-mnli"  # or "microsoft/deberta-v3-large"
ENTAIL_THR  = 0.50                   # default probability threshold used by check_entailment

# Load the tokenizer and classifier once; the classifier is moved to DEVICE and put in eval mode
tok = AutoTokenizer.from_pretrained(MNLI_MODEL)
clf = AutoModelForSequenceClassification.from_pretrained(MNLI_MODEL).to(DEVICE).eval()
# Look the entailment class up by label name (first label containing "entail",
# case-insensitive) instead of hard-coding an index
ENTAIL_ID = None
for i, lab in clf.config.id2label.items():
    if "entail" in lab.lower(): ENTAIL_ID = int(i); break
assert ENTAIL_ID is not None, "Could not find ENTAILMENT label."

def entail_prob(premise: str, hypothesis: str) -> float:
    """
    Probability that `premise` entails `hypothesis` according to the NLI model.

    The pair is tokenized together and truncated to 512 tokens. Returns 0.0
    without calling the model when either text is empty.
    """
    if not (premise and hypothesis):
        return 0.0

    encoded = tok(premise, hypothesis, return_tensors="pt", truncation=True, max_length=512)
    encoded = encoded.to(DEVICE)
    with torch.no_grad():
        logits = clf(**encoded).logits

    class_probs = F.softmax(logits, dim=-1).squeeze().detach().cpu().numpy()
    return float(class_probs[ENTAIL_ID])

def check_entailment(premise: str, hypothesis: str, thr: float = ENTAIL_THR) -> bool:
    """
    Decide whether `premise` entails `hypothesis`.

    Args:
        premise: Supporting text.
        hypothesis: Statement to verify.
        thr: Minimum entailment probability counted as support.

    Returns:
        True when entail_prob(premise, hypothesis) >= thr. If the scorer
        raises, the pair is counted as not entailed and a RuntimeWarning is
        emitted, so scorer failures are visible instead of silently lowering
        the metrics.
    """
    import warnings  # local import keeps the block self-contained

    try:
        return entail_prob(premise, hypothesis) >= thr
    except Exception as exc:
        warnings.warn(
            f"entailment check failed ({type(exc).__name__}: {exc}); counting as not entailed",
            RuntimeWarning,
            stacklevel=2,
        )
        return False

# -----------------------------
# ALCE-style NLI metrics
# -----------------------------
def citation_recall(row, output: str) -> float:
    """
    Fraction of statements entailed by the concatenation of their cited passages.

    Statements with no in-range citation count as unsupported. Returns 0.0
    when the output has no statements.
    """
    statements = segment_statements(output)
    cmap = align_citations(output)
    if not statements:
        return 0.0

    def _is_supported(position, statement):
        in_range = [k for k in cmap.get(position, []) if 1 <= k <= len(row["docs"])]
        if not in_range:
            return False
        evidence = " ".join(row["docs"][k - 1]["text"] for k in in_range)
        return check_entailment(evidence, statement)

    n_supported = sum(
        1 for position, statement in enumerate(statements)
        if _is_supported(position, statement)
    )
    return n_supported / len(statements)

def citation_precision(row, output: str) -> float:
    """
    ALCE-style citation precision: (#relevant citations) / (#all citations).

    Every in-range citation counts towards the denominator. Citations of a
    statement that its cited passages do not jointly entail are never
    relevant. For a supported statement, a citation is irrelevant only when
    its own passage does not entail the statement while the remaining cited
    passages do. Returns 0.0 when there are no citations.
    """
    stmts = segment_statements(output)
    cmap = align_citations(output)
    n_citations = 0
    n_relevant = 0

    for position, stmt in enumerate(stmts):
        cited = [k for k in cmap.get(position, []) if 1 <= k <= len(row["docs"])]
        passages = [row["docs"][k - 1]["text"] for k in cited]
        n_citations += len(cited)

        # Statement-level support check against all cited passages together.
        if not passages or not check_entailment(" ".join(passages), stmt):
            continue

        for j, passage in enumerate(passages):
            supports_alone = check_entailment(passage, stmt)
            others = passages[:j] + passages[j + 1:]
            supported_without = check_entailment(" ".join(others), stmt) if others else False
            # Equivalent to: not (not supports_alone and supported_without)
            if supports_alone or not supported_without:
                n_relevant += 1

    return (n_relevant / n_citations) if n_citations > 0 else 0.0

# -----------------------------
# Few-shot example selection (fixed across run)
# -----------------------------
# Cache file (relative to the working directory) and number of few-shot exemplars.
EXAMPLE_IDS_PATH = "fewshot_example_ids.json"
NUM_EXAMPLES = 3

def has_two_docs(qid):
    """True when the df_all row with id `qid` has a list of at least two docs."""
    example = df_all[df_all["id"] == qid].iloc[0]
    docs = example.get("docs", [])
    return isinstance(docs, list) and len(example["docs"]) >= 2

def get_fixed_example_ids():
    """
    Return the few-shot example ids, reusing the cached selection when possible.

    Cached ids that no longer exist in df_all are discarded. If too few
    remain, a fresh random sample is drawn (preferring items with at least
    two docs), written to EXAMPLE_IDS_PATH, and returned.
    """
    pool_ids = df_all["id"].tolist()
    needed = min(NUM_EXAMPLES, len(pool_ids))

    if os.path.exists(EXAMPLE_IDS_PATH):
        with open(EXAMPLE_IDS_PATH, "r") as fh:
            cached = [qid for qid in json.load(fh) if qid in pool_ids]
        if len(cached) >= needed:
            return cached

    # No usable cache: pick fresh ids and persist them.
    with_two_docs = list(filter(has_two_docs, pool_ids))
    base = with_two_docs if len(with_two_docs) >= NUM_EXAMPLES else pool_ids
    picked = random.sample(base, k=min(NUM_EXAMPLES, len(base)))
    with open(EXAMPLE_IDS_PATH, "w") as fh:
        json.dump(picked, fh)
    return picked

# -----------------------------
# Adapter for your enforcement caller
# -----------------------------
def make_builder_adapter(name, builder, example_ids, n_docs):
    """
    Wrap a prompt builder behind a uniform (row, n_docs_param=None) call.

    The "few-shot" strategy additionally receives `example_ids`. A falsy
    `n_docs_param` (None or 0) falls back to the `n_docs` given here.
    """
    def _with_examples(r, n_docs_param=None):
        return builder(r, example_ids, n_docs=(n_docs_param or n_docs))

    def _plain(r, n_docs_param=None):
        return builder(r, n_docs=(n_docs_param or n_docs))

    return _with_examples if name == "few-shot" else _plain

# -----------------------------
# Main evaluation loop
# -----------------------------
def evaluate_asqa(sample_frac=SAMPLE_FRAC, n_docs=N_DOCS):
    """
    Run every registered strategy over df_all and score each answer.

    Args:
        sample_frac: Fraction of df_all to evaluate; values below 1.0 draw a
            seeded random sample, otherwise the full frame is used.
        n_docs: Number of passages shown to the model.

    Returns:
        DataFrame with one row per (strategy, item): overlap precision and
        recall, NLI precision and recall, and the normalized answer text
        (including any appended Citations block).
    """
    example_ids = get_fixed_example_ids()
    print("Few-shot example IDs:", example_ids)

    # Sample or take everything, then renumber rows.
    if sample_frac < 1.0:
        data = df_all.sample(frac=sample_frac, random_state=RANDOM_SEED)
    else:
        data = df_all.copy()
    data = data.reset_index(drop=True)

    records = []
    for strat, builder in STRATEGY_BUILDERS.items():
        adapter = make_builder_adapter(strat, builder, example_ids, n_docs)
        print(f"\n=== Strategy: {strat} on {len(data)} items ===")
        for _, row in tqdm(data.iterrows(), total=len(data), desc=strat):
            try:
                out = generate_with_enforcement(adapter, row, n_docs=n_docs)
            except TypeError:
                # Fallback: build the prompt directly and call the base generator.
                out = llama_generate(adapter(row, n_docs_param=n_docs))

            out = normalize_final_answer(out, row)
            out = ensure_citations_block(row, out)  # readability only

            overlap_precision, overlap_recall = overlap_metrics(row, out)
            nli_recall = citation_recall(row, out)
            nli_precision = citation_precision(row, out)

            records.append(dict(
                strategy=strat,
                id=row["id"],
                precision_overlap=overlap_precision,
                recall_overlap=overlap_recall,
                precision_nli=nli_precision,
                recall_nli=nli_recall,
                answer=out,
            ))
    return pd.DataFrame(records)

# -----------------------------
# Bootstrap confidence intervals (per strategy, per metric)
# -----------------------------
def bootstrap_ci(df, metric: str, group_col: str = "strategy",
                 B: int = BOOT_B, alpha: float = ALPHA, seed: int = RANDOM_SEED):
    """
    Percentile-bootstrap confidence interval of a metric's mean, per group.

    Args:
        df: Per-example results.
        metric: Column whose mean is bootstrapped.
        group_col: Column defining the groups.
        B: Number of resamples per group.
        alpha: Interval bounds are the alpha/2 and 1 - alpha/2 percentiles.
        seed: Seed of the random generator, shared across groups.

    Returns:
        DataFrame with columns strategy, metric, mean, ci_low, ci_high, n,
        sorted by metric then strategy. Empty groups are skipped.
    """
    rng = np.random.default_rng(seed)
    lower_pct = 100*(alpha/2)
    upper_pct = 100*(1 - alpha/2)

    summaries = []
    for strat, sub in df.groupby(group_col):
        vals = sub[metric].to_numpy()
        n = len(vals)
        if n == 0:
            continue
        # Resample n values with replacement, B times, recording each resample mean.
        boots = np.array(
            [vals[rng.integers(0, n, size=n)].mean() for _ in range(B)],
            dtype=float,
        )
        summaries.append({
            "strategy": strat,
            "metric": metric,
            "mean": float(vals.mean()),
            "ci_low": float(np.percentile(boots, lower_pct)),
            "ci_high": float(np.percentile(boots, upper_pct)),
            "n": int(n),
        })
    return pd.DataFrame(summaries).sort_values(["metric","strategy"])

def summarize_with_ci(results_df: pd.DataFrame):
    """
    Summarize per-example results by strategy.

    Returns:
        (mean_tbl, wide, ci_df): per-strategy means of the four metrics; a
        wide table with one "low – high" interval string per strategy and
        metric (bounds rounded to 3 decimals); and the long-format bootstrap
        table the intervals come from.
    """
    metrics = ["precision_overlap","recall_overlap","precision_nli","recall_nli"]

    # Mean table
    per_strategy = results_df.groupby("strategy")[metrics].mean()
    mean_tbl = per_strategy.reset_index().sort_values("strategy")

    # One bootstrap table per metric, stacked
    ci_df = pd.concat(
        [bootstrap_ci(results_df, metric=m) for m in metrics],
        ignore_index=True,
    )

    # Wide format for quick viewing
    low_text = ci_df["ci_low"].round(3).astype(str)
    high_text = ci_df["ci_high"].round(3).astype(str)
    labelled = ci_df.assign(ci=low_text + " – " + high_text)
    wide = labelled.pivot(index="strategy", columns="metric", values="ci")
    wide = wide.reset_index().sort_values("strategy")
    return mean_tbl, wide, ci_df

# ============================================================
# RUN
# ============================================================
# Time the full evaluation (every registered strategy over the selected items)
start = time.time()
results = evaluate_asqa(sample_frac=SAMPLE_FRAC, n_docs=N_DOCS)
dur = time.time()-start
print(f"\nDone. Evaluated {len(results)} (strategy,example) pairs in {dur/60:.1f} min.")

# Save per-example metrics + answers (paths are relative to the current working directory)
results.to_csv("asqa_results_per_example.csv", index=False)
display(results.head())

# Summaries + CIs
mean_tbl, ci_wide, ci_long = summarize_with_ci(results)
print("\n=== Means ===")
display(mean_tbl)

print("\n=== 95% bootstrap CIs (wide) ===")
display(ci_wide)

# Save CI tables too
ci_long.to_csv("asqa_results_bootstrap_cis_long.csv", index=False)
ci_wide.to_csv("asqa_results_bootstrap_cis_wide.csv", index=False)
print("\nSaved: asqa_results_per_example.csv, asqa_results_bootstrap_cis_[wide|long].csv")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/688 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.43G [00:00<?, ?B/s]

Some weights of the model checkpoint at roberta-large-mnli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Few-shot example IDs: ['5489717211503600508', '8630912480840635425', '8655001621666816912']

=== Strategy: chain-of-thought on 948 items ===


chain-of-thought: 100%|██████████| 948/948 [14:39<00:00,  1.08it/s]



=== Strategy: chain-citation on 948 items ===


chain-citation: 100%|██████████| 948/948 [14:45<00:00,  1.07it/s]



=== Strategy: chain-quote on 948 items ===


chain-quote: 100%|██████████| 948/948 [16:13<00:00,  1.03s/it]


Done. Evaluated 2844 (strategy,example) pairs in 45.6 min.





Unnamed: 0,strategy,id,precision_overlap,recall_overlap,precision_nli,recall_nli,answer
0,chain-of-thought,-7013890438520559398,0.0,0.0,1.0,1.0,Pelé has the highest goals in men's world inte...
1,chain-of-thought,7089015503030534342,1.0,0.666667,0.0,0.0,Paul Simon [1][2][3]\n\nCitations:\n1. The Sou...
2,chain-of-thought,8793099883447006698,1.0,1.0,0.0,0.0,"The first iPhone was released on June 29, 2007..."
3,chain-of-thought,-881464876144297194,0.5,0.25,0.0,0.0,Domhnall Gleeson [1][3]\n\nCitations:\n1. Orde...
4,chain-of-thought,1650309494326541834,0.0,0.0,0.0,0.0,There are six state parks in Virginia in 1936....



=== Means ===


Unnamed: 0,strategy,precision_overlap,recall_overlap,precision_nli,recall_nli
0,chain-citation,0.455767,0.462781,0.227256,0.22358
1,chain-of-thought,0.527409,0.505083,0.243835,0.25875
2,chain-quote,0.485338,0.513921,0.25259,0.258066



=== 95% bootstrap CIs (wide) ===


metric,strategy,precision_nli,precision_overlap,recall_nli,recall_overlap
0,chain-citation,0.205 – 0.252,0.43 – 0.485,0.202 – 0.248,0.435 – 0.489
1,chain-of-thought,0.217 – 0.269,0.501 – 0.553,0.233 – 0.284,0.48 – 0.53
2,chain-quote,0.23 – 0.278,0.46 – 0.51,0.235 – 0.284,0.489 – 0.539



Saved: asqa_results_per_example.csv, asqa_results_bootstrap_cis_[wide|long].csv


In [None]:
def build_rag_block(row: pd.Series, n_docs: int = N_DOCS) -> str:
    """Render the first ``n_docs`` retrieved passages of ``row`` as a numbered text block.

    Args:
        row: A record whose ``'docs'`` entry is a sequence of passage dicts
            (the ``'title'`` and ``'text'`` keys are read).
        n_docs: How many passages to include. Defaults to ``N_DOCS``; the
            original signature used ``N_DOCS`` as a type annotation, which
            left the parameter without a default.

    Returns:
        A header line followed by one ``"<i>. Title: ..."`` entry per passage,
        each passage truncated to its first 100 whitespace-separated tokens.
    """
    lines = ['Retrieved Passages (top {}/100):'.format(n_docs)]
    for i, p in enumerate(row['docs'][:n_docs], start=1):
        title = p.get('title','N/A')
        # The key may exist with a None value; treat that as empty text.
        text  = (p.get('text') or '').replace('\n',' ').strip()
        snippet = ' '.join(text.split()[:100])
        lines.append(f"{i}. Title: {title}\n   {snippet}\n")
    return '\n'.join(lines)
def extract_cited_titles(llm_output: str) -> list:
    """Return the titles listed after the first 'Citations:' header of a model output.

    Only lines shaped like ``"<number>. <title>"`` are kept; surrounding
    whitespace and trailing periods are removed from each title. An output
    without a 'Citations:' header (matched case-insensitively) yields ``[]``.
    """
    sections = re.split(r"(?i)Citations:", llm_output)
    if len(sections) < 2:
        return []
    numbered_entry = re.compile(r"\s*(\d+)\.\s+(.*)")
    hits = (numbered_entry.match(raw_line) for raw_line in sections[1].splitlines())
    return [hit.group(2).strip().rstrip('.') for hit in hits if hit]

def llama_generate(
    prompt: str,
    max_tokens: int = 512,
    temperature: float = 0.0,
    top_p: float = 1.0
) -> str:
    """
    Send ``prompt`` as a single user message to the ``llama3-8b-8192`` model
    through the module-level Groq ``client`` and return the generated text.

    Args:
        prompt: Full prompt text, sent as the only chat message.
        max_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature (0.0 by default).
        top_p: Nucleus-sampling cutoff.

    Returns:
        The first choice's message content with surrounding whitespace
        removed, or an empty string when the API returns no content.
    """
    # 1. Call the chat/completion endpoint
    chat_completion = client.chat.completions.create(
        model="llama3-8b-8192",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        # You can also pass stop_sequences=[...] if you want to force a stop.
    )

    # 2. Extract the generated text via attribute access
    #    (NOT via chat_completion["choices"][0]["message"]["content"])
    generated_text = chat_completion.choices[0].message.content
    # The content field is optional in the response; avoid calling .strip() on None.
    return (generated_text or "").strip()
def zero_shot_explicit_rag_prompt(row: pd.Series, n_docs: int = N_DOCS) -> str:
    """Build the zero-shot prompt: the question, an inline-citation instruction, and the passages.

    ``n_docs`` defaults to ``N_DOCS`` (the original signature used ``N_DOCS``
    as an annotation, so callers had to pass it explicitly).
    """
    q   = row['question']
    rag = build_rag_block(row, n_docs)
    return f"""
Question:
{q}
Answer using these passages if relevant.Whenever you make a factual statement, immediately append “[<k>]” to reference which passage supports it.
{rag}

...""".strip()

def few_shot_qa_explicit_rag_prompt(row: pd.Series,
                                    example_ids: list,
                                    n_docs: int = N_DOCS) -> str:
    """Build a few-shot prompt: worked examples from ``df_all`` followed by the target question.

    Args:
        row: The record to answer (its ``'question'`` and ``'docs'`` are read).
        example_ids: IDs looked up in ``df_all['id']``; each contributes one
            demonstration made of its question, the first line of its
            reference answer and the title of its first passage.
        n_docs: Number of passages shown for the target question. Defaults to
            ``N_DOCS`` (previously ``N_DOCS`` was only an annotation, so the
            argument was mandatory).

    Returns:
        The assembled prompt text.
    """
    prefix = []
    for ex_id in example_ids:
        ex = df_all[df_all['id']==ex_id].iloc[0]
        ex_q = ex['question']
        ex_a = ex['reference_answer'].split('\n')[0]
        # pick top-1 passage title as demo
        demo_title = ex['docs'][0].get('title','N/A')
        block = f"""
Example:
Question: {ex_q}
Answer: {ex_a}

Citations:
  1. {demo_title}
"""
        prefix.append(block.strip())
    demo_block = '\n\n'.join(prefix)
    q, rag = row['question'], build_rag_block(row, n_docs)
    return f"""
The following are a few examples on how to answer questions

{demo_block}

===== Now your turn: =====

Question: {q}
Answer using these passages if relevant.Whenever you make a factual statement, immediately append “[<k>]” to reference which passage supports it.
Below are examples of how to answer and cite:
{rag}

...""".strip()

def cot_standard_explicit_rag_prompt(row: pd.Series, n_docs: int = N_DOCS) -> str:
    """Build the standard chain-of-thought prompt (expert-QA framing, three steps, 'Citations:' list).

    ``n_docs`` is the number of passages rendered into the prompt and defaults
    to ``N_DOCS`` (it was previously an annotation only, hence mandatory).
    """
    q, rag = row['question'], build_rag_block(row, n_docs)
    return f"""
You are an expert QA system. Use retrieved passages to reason step by step, then answer concisely.
Finally append a literal 'Citations:' list of page titles.

Steps:
  1. Identify key concepts.
  2. Search for evidence in the passages.
  3. Construct answer, then list used titles under 'Citations:'.

Question:
{q}
Answer using these passages if relevant.Whenever you make a factual statement, immediately append “[<k>]” to reference which passage supports it.

{rag}

...""".strip()

def cot_punishment_explicit_rag_prompt(row: pd.Series, n_docs: int = N_DOCS) -> str:
    """Build the chain-of-thought prompt that warns that bad citations are penalized.

    ``n_docs`` is the number of passages rendered into the prompt and defaults
    to ``N_DOCS`` (it was previously an annotation only, hence mandatory).
    """
    q, rag = row['question'], build_rag_block(row, n_docs)
    return f"""
You are an AI assistant evaluated for factual accuracy.
Incorrect or hallucinated citations will be penalized.
Only cite sources directly from the passages.

Steps:
  1. Identify key concepts.
  2. Find evidence in passages.
  3. Construct answer and list used titles under 'Citations:'.

Question:
{q}
Answer using these passages if relevant.Whenever you make a factual statement, immediately append “[<k>]” to reference which passage supports it.
{rag}

""".strip()

def cot_praise_explicit_rag_prompt(row: pd.Series, n_docs: int = N_DOCS) -> str:
    """Build the chain-of-thought prompt that frames the model as known for citation accuracy.

    ``n_docs`` is the number of passages rendered into the prompt and defaults
    to ``N_DOCS`` (it was previously an annotation only, hence mandatory).
    """
    q, rag = row['question'], build_rag_block(row, n_docs)
    return f"""
You are known for high-quality answers and citation accuracy.
Provide accurate citations to build trust.

Steps:
  1. Identify key concepts.
  2. Search evidence in passages.
  3. Construct answer and list used titles under 'Citations:'.

Question:
{q}
Answer using these passages if relevant.Whenever you make a factual statement, immediately append “[<k>]” to reference which passage supports it.
{rag}

...""".strip()

def cot_academic_explicit_rag_prompt(row: pd.Series, n_docs: int = N_DOCS) -> str:
    """Build the chain-of-thought prompt that uses an academic-reputation framing.

    ``n_docs`` is the number of passages rendered into the prompt and defaults
    to ``N_DOCS`` (it was previously an annotation only, hence mandatory).
    """
    q, rag = row['question'], build_rag_block(row, n_docs)
    return f"""
You are a rising academic in AI; your reputation depends on accurate, well-supported citations.
Misattributed or missing citations damage credibility—cite with care.

Steps:
  1. Identify key concepts.
  2. Find evidence in passages.
  3. Construct answer and list used titles under 'Citations:'.

Question:
{q}
Answer using these passages if relevant.Whenever you make a factual statement, immediately append “[<k>]” to reference which passage supports it.
{rag}

...""".strip()

def chain_of_citation_prompt(row: pd.Series, n_docs: int = N_DOCS) -> str:
    """Build the chain-of-citation prompt: step-by-step reasoning with a passage marker after each step.

    ``n_docs`` is the number of passages rendered into the prompt and defaults
    to ``N_DOCS`` (it was previously an annotation only, hence mandatory).
    """
    q   = row['question']
    rag = build_rag_block(row, n_docs)
    return f"""
You are an expert QA system. Use the following retrieved passages to reason in a chain‐of‐thought, citing as you go. Use the following steps:
    • Reason step by step.
    • After each step, append “[<k>]” to show which passage supports it.
    • Finish with a concise Answer and a numbered Citations list.

Question:
{q}

{rag}


…""".strip()


def chain_of_quote_prompt(row: pd.Series, n_docs: int = N_DOCS) -> str:
    """Build the chain-of-quote prompt: each reasoning step must carry a direct quotation plus a citation.

    ``n_docs`` is the number of passages rendered into the prompt and defaults
    to ``N_DOCS`` (it was previously an annotation only, hence mandatory).
    """
    q   = row['question']
    rag = build_rag_block(row, n_docs)
    return f"""
You are an expert QA system. Use these passages to build a chain of reasoning. At each step:
  • Reason step by step.
  • For each step, include a direct quotation from the supporting passage in quotes,
    then append “[<k>]” to cite it.
  • Finish with final Answer + Citations list.

Question:
{q}

{rag}

…""".strip()


def chain_of_verification_prompt(row: pd.Series, n_docs: int = N_DOCS) -> str:
    """Build the chain-of-verification prompt: each claim is stated, then marked verified or not.

    ``n_docs`` is the number of passages rendered into the prompt and defaults
    to ``N_DOCS`` (it was previously an annotation only, hence mandatory).
    """
    q   = row['question']
    rag = build_rag_block(row, n_docs)
    return f"""
You are an AI assistant. Use chain‐of‐thought with explicit verification for each step. Use the following steps
    • For each reasoning step, state the claim.
    • Then verify it by checking the passages:
      – Write “Verified [<k>]” if the passage supports it.
      – Otherwise write “Not verified [<k>]”.
    • End with final Answer + Citations.
Question:
{q}

{rag}

…""".strip()


def self_verification_prompt(row: pd.Series, n_docs: int = N_DOCS) -> str:
    """Build the self-verification prompt: answer with citations, review them, present the cleaned answer.

    ``n_docs`` is the number of passages rendered into the prompt and defaults
    to ``N_DOCS`` (it was previously an annotation only, hence mandatory).
    """
    q   = row['question']
    rag = build_rag_block(row, n_docs)
    return f"""
You are an AI assistant. First produce an answer with citations, then verify your own work with the following steps:
    1) Generate your answer with citations.
    2) Then review your own answer: for each citation,
       check whether it truly supports the statement; if not, remove it.
    3) Present the cleaned‐up answer and final Citations.
Question:
{q}

{rag}

…""".strip()

# (strategy name, prompt builder) pairs.
# The few-shot entry binds its demonstrations to the first three IDs of df_all
# (looked up when the lambda is called) and forwards n_docs explicitly: the
# underlying builder needs that argument, and the previous lambda never passed it.
tasks = [
    ('zero-shot', zero_shot_explicit_rag_prompt),
    ('few-shot', lambda r, n_docs=N_DOCS: few_shot_qa_explicit_rag_prompt(
        r,
        example_ids=[df_all['id'].iloc[0], df_all['id'].iloc[1], df_all['id'].iloc[2]],
        n_docs=n_docs,
    )),
    ('cot-standard', cot_standard_explicit_rag_prompt),
    ('cot-punishment', cot_punishment_explicit_rag_prompt),
    ('cot-praise', cot_praise_explicit_rag_prompt),
    ('cot-academic', cot_academic_explicit_rag_prompt),
    ('chain-citation', chain_of_citation_prompt),
    ('chain-quote',    chain_of_quote_prompt),
    ('chain-verify',   chain_of_verification_prompt),
    ('self-verify',    self_verification_prompt)
]


In [None]:
pip install groq



In [None]:
from groq import Groq


# Read the key from the environment instead of embedding it in the notebook:
# a literal key leaks when the notebook is shared, and an empty literal only
# fails later, at the first request.
client = Groq(
    api_key=os.environ.get("GROQ_API_KEY")
)


In [None]:

FORMAT_RULES = (
    "Answer using only the passages below. Be concise and precise. "
    "In the FINAL ANSWER, append a citation like [k] after every sentence, "
    "where k is the number of the supporting passage shown (1..{n}). "
    "If more than one passage supports a sentence, use separate brackets, e.g., [1][3]. "
    "Use square brackets only. Do not restate the question. "
    "Avoid unsupported generalizations (e.g., 'widely recognized', 'unlikely')."
)

def _prefix(strategy_text: str, n_docs: int) -> str:
    return (
        f"{strategy_text}\n"
        + FORMAT_RULES.format(n=n_docs)
        + "\n\nQuestion:\n{q}\n\nRetrieved Passages (numbered):\n{rag}\n"
        "\nFINAL ANSWER:\n"
    )


def zero_shot_explicit_rag_prompt(row, n_docs: int = 5) -> str:
    """Zero-shot prompt: shared format rules, the question and the numbered passages."""
    question = row["question"]
    passages = build_rag_block(row, n_docs)
    template = _prefix("Zero-shot answering with inline evidence.", n_docs)
    return template.format(q=question, rag=passages)

def cot_standard_explicit_rag_prompt(row, n_docs: int = 5) -> str:
    """Chain-of-thought prompt: reason briefly, but emit only the final answer."""
    question = row["question"]
    passages = build_rag_block(row, n_docs)
    instruction = " ".join([
        "Reason briefly using the passages, then output only the FINAL ANSWER.",
        "Do not show your reasoning.",
    ])
    template = _prefix(instruction, n_docs)
    return template.format(q=question, rag=passages)

def chain_of_citation_prompt(row, n_docs: int = 5) -> str:
    """Chain-of-citation prompt: cite every reasoning step, then give the cited final answer."""
    question = row["question"]
    passages = build_rag_block(row, n_docs)
    instruction = " ".join([
        "Use the passages to reason step by step, appending [k] to each factual step.",
        "THEN provide the FINAL ANSWER with [k] after every sentence.",
    ])
    template = _prefix(instruction, n_docs)
    return template.format(q=question, rag=passages)

def chain_of_quote_prompt(row, n_docs: int = 5) -> str:
    """Chain-of-quote prompt: each reasoning step quotes its passage before the cited final answer."""
    question = row["question"]
    passages = build_rag_block(row, n_docs)
    instruction = " ".join([
        "Build your reasoning step by step.",
        "At each step, include a short direct quote from the supporting passage and append [k].",
        "THEN provide the FINAL ANSWER with [k] after every sentence.",
    ])
    template = _prefix(instruction, n_docs)
    return template.format(q=question, rag=passages)

def chain_of_verification_prompt(row, n_docs: int = 5) -> str:
    """Chain-of-verification prompt: draft, verify claims against passages, keep supported ones."""
    question = row["question"]
    passages = build_rag_block(row, n_docs)
    instruction = " ".join([
        "Draft your answer, then verify each claim against the passages.",
        "Keep only claims that are directly supported.",
        "Finally output the FINAL ANSWER with [k] after every sentence.",
    ])
    template = _prefix(instruction, n_docs)
    return template.format(q=question, rag=passages)

def self_verification_prompt(row, n_docs: int = 5) -> str:
    """Self-verification prompt: answer with citations, self-check each sentence, emit the final answer."""
    question = row["question"]
    passages = build_rag_block(row, n_docs)
    instruction = " ".join([
        "First, produce an answer with inline citations.",
        "Then self-check each sentence: if its citation does not support it, revise or remove it.",
        "Output only the FINAL ANSWER.",
    ])
    template = _prefix(instruction, n_docs)
    return template.format(q=question, rag=passages)

def motivation_prompt(row, n_docs: int = 5) -> str:
    """Motivation prompt: tells the model that attribution quality is evaluated and penalized."""
    question = row["question"]
    passages = build_rag_block(row, n_docs)
    instruction = " ".join([
        "Your output will be evaluated for factual accuracy and proper attribution.",
        "Incorrect, missing, or hallucinated citations will be penalized.",
    ])
    template = _prefix(instruction, n_docs)
    return template.format(q=question, rag=passages)

def role_prompt(row, n_docs: int = 5) -> str:
    """Role prompt: frames the model as an NLP/ML researcher answering from the passages."""
    question = row["question"]
    passages = build_rag_block(row, n_docs)
    instruction = "You are a researcher specializing in NLP/ML. Use the passages to answer concisely."
    template = _prefix(instruction, n_docs)
    return template.format(q=question, rag=passages)


def few_shot_qa_explicit_rag_prompt(row, example_ids, n_docs: int = 5) -> str:
    """Few-shot prompt: format rules, schematic demonstrations, then the target question and passages.

    Args:
        row: Record to answer (its ``"question"`` and ``"docs"`` are read).
        example_ids: IDs looked up in ``df_all["id"]``; each contributes one
            demonstration built from its question and the titles of its first
            two passages (placeholders are used when fewer exist).
        n_docs: Number of passages shown for the target question.

    Returns:
        The assembled prompt text, ending with the ``FINAL ANSWER:`` cue.
    """
    demos = []
    for ex_id in example_ids:
        ex = df_all[df_all["id"] == ex_id].iloc[0]
        ex_q = ex["question"]
        ex_docs = ex["docs"]
        t1 = ex_docs[0].get("title", "Passage 1") if len(ex_docs) >= 1 else "Passage 1"
        t2 = ex_docs[1].get("title", "Passage 2") if len(ex_docs) >= 2 else "Passage 2"

        # Interpolate everything with f-strings. The previous code called
        # .format() on a string that already contained the question text, so
        # any '{' or '}' in a question raised or was substituted by mistake.
        demos.append(
            f"Question: {ex_q}\nAnswer: Fact A. [1] Fact B. [2]\n"
            f"Citations:\n1. {t1}\n2. {t2}\n"
        )
    demo_block = "\n\n".join(demos)

    q, rag = row["question"], build_rag_block(row, n_docs)
    header = "Few-shot answering with inline evidence.\n" + FORMAT_RULES.format(n=n_docs)
    return (
        header + "\n\nExamples:\n" + demo_block +
        f"\n\nQuestion:\n{q}\n\nRetrieved Passages (numbered):\n{rag}\n\nFINAL ANSWER:\n"
    )


In [None]:
# Extract only the 'Answer:' portion before 'Citations:'
# Both patterns below are raw strings, so regex escapes take a SINGLE backslash.
# The earlier doubled form (`\\s`, `\\n`) asked for a literal backslash in the
# text, so neither pattern could match ordinary model output.
ANSWER_PAT = re.compile(r'(?is)answer\s*:\s*(.*?)(?:\n\s*citations\s*:\s*|$)')

def extract_answer_text(output: str) -> str:
    """Return the text after the first 'answer:' header, up to a 'Citations:' line or the end.

    The header is matched case-insensitively anywhere in the text (so
    'FINAL ANSWER:' qualifies). When no header is found, everything before a
    literal ``"Citations:"`` is returned instead. The result is stripped.
    """
    m = ANSWER_PAT.search(output)
    return m.group(1).strip() if m else output.split("Citations:")[0].strip()

# Split into sentences: after '.', '!' or '?', at whitespace that is followed
# by an uppercase letter or a digit.
SENT_SPLIT = re.compile(r'(?<=[.!?])\s+(?=[A-Z0-9])')
def segment_statements(output: str) -> list:
    """Split the answer portion of ``output`` into a list of non-empty, stripped sentences."""
    text = extract_answer_text(output)
    parts = SENT_SPLIT.split(text)
    return [p.strip() for p in parts if p.strip()]

# Accept [n], 【n】, or (n)
# one capture group, any “opening bracket” char
# Accept multiple styles of brackets
# A number wrapped in one opening and one closing bracket character
# (ASCII or full-width/CJK variants), with optional inner whitespace.
CITE_RE = re.compile(r"[［\[\(【〔]\s*(\d+)\s*[］\]\)】〕]")

def align_citations(output: str) -> dict:
    """Map each sentence position of ``output`` to the sorted, de-duplicated citation numbers it contains."""
    return {
        position: sorted(set(map(int, CITE_RE.findall(sentence))))
        for position, sentence in enumerate(segment_statements(output))
    }





In [None]:
# Use roberta-large-mnli for entailment probability
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Run on the GPU when torch reports one, otherwise on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Tokenizer and sequence-classification model are both loaded from the
# "roberta-large-mnli" checkpoint; the model is moved to DEVICE.
tok = AutoTokenizer.from_pretrained("roberta-large-mnli")
clf = AutoModelForSequenceClassification.from_pretrained("roberta-large-mnli").to(DEVICE)
# Switch the model to evaluation mode.
clf.eval()
# Softmax over dimension 1 of the logits returned by the classifier.
_softmax = torch.nn.Softmax(dim=1)
# Index of the first label in the model config whose name contains "ENTAIL"
# (case-insensitive); raises IndexError if the config has no such label.
ENTAIL_ID = [i for i, lab in clf.config.id2label.items() if "ENTAIL" in lab.upper()][0]
from functools import lru_cache

@lru_cache(maxsize=100_000)
def cached_entail(premise: str, hypothesis: str, thr: float = 0.5) -> bool:
    """Memoized check that the entailment probability of (premise, hypothesis) reaches ``thr``."""
    return entailment_prob(premise, hypothesis) >= thr

def entailment_prob(premise: str, hypothesis: str, max_length=384) -> float:
    """Return the classifier's probability for the entailment label on (premise, hypothesis).

    An empty premise or hypothesis short-circuits to ``0.0`` without running
    the model. The pair is tokenized with truncation to ``max_length``.
    """
    if not (premise and hypothesis):
        return 0.0
    encoded = tok(
        premise,
        hypothesis,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    ).to(DEVICE)
    with torch.no_grad():
        raw_scores = clf(**encoded).logits
    label_probs = _softmax(raw_scores).squeeze(0).detach().cpu().numpy()
    return float(label_probs[ENTAIL_ID])

def check_entailment(premise: str, hypothesis: str, thr: float = 0.5) -> bool:
    """Uncached check that the entailment probability of (premise, hypothesis) reaches ``thr``."""
    probability = entailment_prob(premise, hypothesis)
    return probability >= thr

def citation_recall(row, output: str) -> float:
    """Fraction of cited sentences that are entailed by the union of their cited passages.

    Only sentences carrying at least one citation marker are counted. Citation
    numbers outside ``1..len(row["docs"])`` are ignored when building the
    premise, so a sentence with only invalid numbers counts as unsupported.

    Returns:
        A value in ``[0, 1]``; ``0.0`` when no sentence has a citation.
    """
    # Segment and align once. The previous version re-ran align_citations for
    # every sentence inside the comprehension.
    stmts = segment_statements(output)
    cmap = align_citations(output)
    pairs = [(i, s) for i, s in enumerate(stmts) if cmap.get(i)]
    if not pairs:
        return 0.0
    supported = 0
    for i, stmt in pairs:
        idxs = [k for k in cmap[i] if 1 <= k <= len(row["docs"])]
        premise = " ".join((row["docs"][k-1]["text"] or "") for k in idxs)
        if cached_entail(premise, stmt):
            supported += 1
    return supported / len(pairs)

def citation_precision(row, output: str) -> float:
    """
    Share of citations judged relevant, following the rule stated for ALCE:
    when the cited passages together do not entail the sentence, none of its
    citations count; otherwise a citation is dropped only if it does not
    entail the sentence on its own while the remaining ones still do.

    Returns 0.0 when there are no valid citations at all.
    """
    sentences = segment_statements(output)
    cite_map = align_citations(output)
    n_citations = 0
    n_relevant = 0
    for sent_pos, sentence in enumerate(sentences):
        cited = [k for k in cite_map.get(sent_pos, []) if 1 <= k <= len(row["docs"])]
        if not cited:
            continue
        passages = [row["docs"][k-1]["text"] or "" for k in cited]
        # Every valid citation of this sentence counts toward the denominator.
        n_citations += len(cited)
        if not cached_entail(" ".join(passages), sentence):
            continue
        for slot, k in enumerate(cited):
            alone_ok = cached_entail(row["docs"][k-1]["text"] or "", sentence)
            others = passages[:slot] + passages[slot+1:]
            others_ok = cached_entail(" ".join(others), sentence) if others else False
            if alone_ok or not others_ok:
                n_relevant += 1
    return n_relevant / n_citations if n_citations > 0 else 0.0


Some weights of the model checkpoint at roberta-large-mnli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
def generate_with_enforcement(builder, row, n_docs=N_DOCS):
    """Generate an answer and, if it has no ``[k]`` citation, ask the model once to add them.

    Args:
        builder: Callable that turns ``row`` into a prompt. If its signature
            declares an ``n_docs`` parameter, ``n_docs`` is forwarded so the
            prompt shows the same number of passages as the rewrite prompt;
            otherwise it is called with ``row`` only, as before.
        row: Record passed to the builder and to ``build_rag_block``.
        n_docs: Number of passages referenced by the rewrite prompt.

    Returns:
        The first generation when its answer text contains a ``[number]``
        marker, otherwise the output of the rewrite request.
    """
    import inspect

    try:
        forwards_n_docs = "n_docs" in inspect.signature(builder).parameters
    except (TypeError, ValueError):
        # Some callables expose no introspectable signature; keep the old call shape.
        forwards_n_docs = False
    prompt = builder(row, n_docs=n_docs) if forwards_n_docs else builder(row)

    out = llama_generate(prompt, max_tokens=MAX_TOKENS, temperature=TEMPERATURE)
    # simple check for [number] in the final answer
    answer_text = extract_answer_text(out)
    if not re.search(r'\[\s*\d+\s*\]', answer_text):
        # ask the model to rewrite with citations
        fix_prompt = f"""Rewrite the FINAL ANSWER below so that EVERY sentence ends with a numeric citation [k] (k∈1..{n_docs}).
Use only square brackets. Do not change the content.

Passages:
{build_rag_block(row, n_docs)}

Original:
{out}
"""
        out = llama_generate(fix_prompt, max_tokens=MAX_TOKENS, temperature=0.0)
    return out


In [None]:
# Use the union of gold pages from df_all["wikipages"] for overlap
# --- helper: map inline index -> passage title ---
def titles_cited_by_indices(row, output: str):
    """
    Return the titles of the passages referenced inline in ``output``.

    The ``[k]`` markers of every sentence are collected, numbers outside
    ``1..len(row['docs'])`` are dropped, and the title of each remaining
    passage is returned in ascending index order. A passage without a
    ``'title'`` key yields an empty string instead of raising ``KeyError``.
    """
    cmap = align_citations(output)
    indices = sorted({k for ks in cmap.values() for k in ks
                      if 1 <= k <= len(row["docs"])})
    return [row["docs"][i-1].get("title", "") for i in indices]

def overlap_metrics(row, output: str):
    """Title-overlap precision and recall between cited passages and ``row``'s gold wikipages.

    Titles are compared lower-cased and stripped. Missing or ``None`` titles
    are tolerated: empty predicted titles are skipped and gold entries without
    a title are read as an empty string instead of raising ``KeyError``.

    Returns:
        ``(precision, recall)``; each is ``0.0`` when its denominator is empty.
    """
    pred_titles = {t.lower().strip() for t in titles_cited_by_indices(row, output) if t}
    gold_titles = {(w.get("title", "") or "").lower().strip()
                   for w in (row.get("wikipages") or [])}
    tp = len(pred_titles & gold_titles)
    return (tp/len(pred_titles) if pred_titles else 0.0,
            tp/len(gold_titles)  if gold_titles  else 0.0)

# Build strategies (few-shot uses a fixed pool)
import os, json, random

EXAMPLE_IDS_PATH = "fewshot_example_ids.json"
NUM_EXAMPLES = 3
SAMPLE_FRAC = 0.2  # your eval fraction

# 1) Build the pool for examples
#    Prefer non-dev (e.g., train) if you loaded it; otherwise use df_all but we’ll exclude them from eval.
pool_ids = df_all["id"].tolist()

# 2) Load fixed examples if they exist; else select once and save
if os.path.exists(EXAMPLE_IDS_PATH):
    with open(EXAMPLE_IDS_PATH, "r") as f:
        example_ids = json.load(f)
    # keep only those that still exist in df_all
    example_ids = [i for i in example_ids if i in pool_ids]
else:
    random.seed(2025)  # deterministic
    example_ids = random.sample(pool_ids, k=min(NUM_EXAMPLES, len(pool_ids)))
    with open(EXAMPLE_IDS_PATH, "w") as f:
        json.dump(example_ids, f)

print("Using fixed few-shot example IDs:", example_ids)

# 3) Build your evaluation sample EXCLUDING the few-shot examples
sample_df = df_all.loc[~df_all["id"].isin(example_ids)].sample(frac=SAMPLE_FRAC, random_state=2025)

# 4) Strategies (unchanged)
strategies = [
    ('zero-shot',        zero_shot_explicit_rag_prompt),
    ('few-shot',         lambda r: few_shot_qa_explicit_rag_prompt(r, example_ids)),
    ('chain-of-thought', cot_standard_explicit_rag_prompt),
    ('chain-citation',   chain_of_citation_prompt),
    ('chain-quote',      chain_of_quote_prompt),
    ('chain-verify',     chain_of_verification_prompt),
    ('self-verify',      self_verification_prompt),
    ('motivation',       motivation_prompt),
    ('role',             role_prompt),
]



# Evaluate over the sample built in step 3.
# The sample is deliberately NOT re-drawn here: an earlier revision re-sampled
# from all of df_all with a hard-coded fraction, which put the few-shot
# demonstrations back into the evaluation set and ignored SAMPLE_FRAC.
results = []
for name, builder in tqdm(strategies):
    for _, row in tqdm(sample_df.iterrows(), total=len(sample_df)):
        out = generate_with_enforcement(builder, row)
        prec_o, rec_o = overlap_metrics(row, out)
        rec_n  = citation_recall(row, out)
        prec_n = citation_precision(row, out)
        results.append({
            "strategy": name,
            "precision_overlap": prec_o,
            "recall_overlap":    rec_o,
            "precision_nli":     prec_n,
            "recall_nli":        rec_n,
        })

# Mean of each metric per strategy.
metrics = pd.DataFrame(results).groupby("strategy")[["precision_overlap","recall_overlap","precision_nli","recall_nli"]].mean()
print(metrics)


In [None]:
# =========================
# One-sample, all strategies: question + prompts + answers + overlap & NLI + debug
# =========================
import os, json, random, re
import pandas as pd
from IPython.display import display

# --- config ---
# Keep an N_DOCS already defined in this session; otherwise fall back to 5.
N_DOCS = globals().get("N_DOCS", 5)  # must match what you show the model

SHOW_PROMPTS = True         # set False to hide full prompts
TRUNCATE_PROMPT = 1200      # limit printed prompt length (avoid flooding)
EXAMPLE_IDS_PATH = "fewshot_example_ids.json"
NUM_EXAMPLES = 3
random.seed(2025)

# --- strategies (reusing your builders) ---
# Ordered name → prompt-builder mapping; "few-shot" additionally needs example_ids.
STRATEGY_BUILDERS = dict([
    ("zero-shot",        zero_shot_explicit_rag_prompt),
    ("few-shot",         few_shot_qa_explicit_rag_prompt),
    ("chain-of-thought", cot_standard_explicit_rag_prompt),
    ("chain-citation",   chain_of_citation_prompt),
    ("chain-quote",      chain_of_quote_prompt),
    ("chain-verify",     chain_of_verification_prompt),
    ("self-verify",      self_verification_prompt),
    ("motivation",       motivation_prompt),
    ("role",             role_prompt),
])

# --- small cleaners/niceties (optional but helpful) ---
QUESTION_LINE_RE = re.compile(r"(?im)^\s*question\s*:\s*")
# Also accept "Final answer:" / "FINAL ANSWER:": the prompts ask the model for
# a FINAL ANSWER, and with a bare "answer:" pattern such a header was never
# recognised, so the preceding reasoning stayed in the normalized text.
ANSWER_LINE_RE   = re.compile(r"(?im)^\s*(?:final\s+)?answer\s*:\s*")
CITATIONS_HDR_RE = re.compile(r"(?im)^\s*Citations\s*:\s*$")

def normalize_final_answer(output: str, row=None) -> str:
    """Keep only the final answer text (drop 'Question:' echo etc.).

    Steps, in order:
      1. If a line starts with an "Answer:" / "Final answer:" header
         (case-insensitive), keep only what follows the first such header.
      2. Drop every line that starts with "Question:".
      3. If ``row`` carries a string ``"question"``, drop lines that are
         exactly that question.

    Falsy input (``None`` or ``""``) is returned unchanged.
    """
    if not output:
        return output
    m = ANSWER_LINE_RE.search(output)
    if m:
        output = output[m.end():].strip()
    lines = [ln for ln in output.splitlines() if not QUESTION_LINE_RE.match(ln)]
    out = "\n".join(lines).strip()
    if row is not None and isinstance(row.get("question"), str):
        q = row["question"].strip()
        out = "\n".join(ln for ln in out.splitlines() if ln.strip() != q).strip()
    return out

def ensure_citations_block(row, output: str) -> str:
    """Append a numbered 'Citations:' list built from the inline [k] markers, unless one exists.

    The text is returned untouched when it is empty, already has a
    'Citations:' header line, or cites no passage within
    ``1..len(row["docs"])``. Titles fall back to ``"Passage <k>"``.
    """
    if not output or CITATIONS_HDR_RE.search(output):
        return output
    cited = set()
    for numbers in align_citations(output).values():
        for number in numbers:
            if 1 <= number <= len(row["docs"]):
                cited.add(number)
    if not cited:
        return output
    block = ["", "Citations:"]
    for rank, k in enumerate(sorted(cited), start=1):
        passage_title = row["docs"][k-1].get("title", f"Passage {k}")
        block.append(f"{rank}. {passage_title}")
    return output.rstrip() + "\n" + "\n".join(block) + "\n"

# --- overlap helpers (title-based, like ASQA side) ---
def titles_cited_by_indices(row, output: str):
    """Titles of the in-range passages cited inline in ``output``, in ascending index order."""
    valid = set()
    for numbers in align_citations(output).values():
        for number in numbers:
            if 1 <= number <= len(row["docs"]):
                valid.add(number)
    return [row["docs"][number - 1].get("title", "") for number in sorted(valid)]

def overlap_metrics(row, output: str):
    """Return (precision, recall) of cited passage titles against ``row``'s gold wikipage titles.

    Titles are lower-cased and stripped before comparison; empty predicted
    titles are ignored. Either value is 0.0 when its denominator is empty.
    """
    predicted = set()
    for title in titles_cited_by_indices(row, output):
        if title:
            predicted.add(title.lower().strip())
    gold = set()
    for page in (row.get("wikipages") or []):
        gold.add((page.get("title","") or "").lower().strip())
    hits = len(predicted & gold)
    precision = hits / len(predicted) if predicted else 0.0
    recall = hits / len(gold) if gold else 0.0
    return precision, recall

# --- adapter: make builders compatible with generate_with_enforcement(builder,row,n_docs) ---
def make_builder_adapter(name, builder, example_ids, n_docs):
    """Wrap ``builder`` into a callable ``(row, n_docs_param=None)`` with a bound passage count.

    A falsy ``n_docs_param`` falls back to the ``n_docs`` given here. For the
    ``"few-shot"`` strategy ``example_ids`` is passed as the second argument.
    """
    if name != "few-shot":
        def adapter(r, n_docs_param=None):
            return builder(r, n_docs=(n_docs_param or n_docs))
    else:
        def adapter(r, n_docs_param=None):
            return builder(r, example_ids, n_docs=(n_docs_param or n_docs))
    return adapter

# --- debug printer for a single (row, answer) ---
def show_overlap_and_nli_debug(row, answer: str):
    """Print the sentences of ``answer`` with their citations, then title-overlap and NLI scores."""
    sentences = segment_statements(answer)
    cite_map = align_citations(answer)
    predicted = titles_cited_by_indices(row, answer)
    gold = [(page.get("title","") or "") for page in (row.get("wikipages") or [])]

    print("  Sentences & indices:")
    for number, sentence in enumerate(sentences, start=1):
        print(f"   {number:>2}. {sentence}")
        print(f"       cites: {cite_map.get(number - 1, [])}")

    overlap_p, overlap_r = overlap_metrics(row, answer)
    nli_p = citation_precision(row, answer)
    nli_r = citation_recall(row, answer)

    print("  Predicted titles:", predicted)
    print("  Gold titles (first 10):", gold[:10])
    print(f"  overlap→ precision={overlap_p:.3f}  recall={overlap_r:.3f}")
    print(f"  NLI    → precision={nli_p:.3f}  recall={nli_r:.3f}")

# --- main: run all strategies on ONE row and print everything nicely ---
def run_all_strategies_with_debug(row, example_ids, n_docs: int = N_DOCS):
    """Run every entry of STRATEGY_BUILDERS on one row, printing prompt, answer and debug info.

    Returns a DataFrame with one record per strategy holding the question,
    the prompt, the post-processed answer and the four metrics.
    """
    rule = "=" * 28
    collected = []
    print("QUESTION:", row["question"])
    for strategy_name, strategy_builder in STRATEGY_BUILDERS.items():
        print("\n" + rule, strategy_name.upper(), rule)
        bound = make_builder_adapter(strategy_name, strategy_builder, example_ids, n_docs)

        # Prompt is built here only for logging/inspection.
        prompt_text = bound(row, n_docs_param=n_docs)
        if SHOW_PROMPTS:
            too_long = len(prompt_text) > TRUNCATE_PROMPT
            shown = prompt_text[:TRUNCATE_PROMPT] + " …" if too_long else prompt_text
            print("\n[Prompt]\n", shown)

        # Generate, then normalize and make sure a citations list is present.
        answer = generate_with_enforcement(bound, row, n_docs=n_docs)
        answer = normalize_final_answer(answer, row)
        answer = ensure_citations_block(row, answer)

        print("\n[Answer]\n", answer.strip())
        print("\n[Debug]")
        show_overlap_and_nli_debug(row, answer)

        # Metrics for this row/strategy pair.
        overlap_p, overlap_r = overlap_metrics(row, answer)
        nli_r = citation_recall(row, answer)
        nli_p = citation_precision(row, answer)

        collected.append(dict(
            strategy=strategy_name,
            question=row["question"],
            prompt=prompt_text,
            answer=answer,
            precision_overlap=overlap_p,
            recall_overlap=overlap_r,
            precision_nli=nli_p,
            recall_nli=nli_r,
        ))
    return pd.DataFrame.from_records(collected)

# =========================
# Choose fixed few-shot examples, pick ONE eval row, run everything
# =========================
# Every ID in df_all is a candidate few-shot example.
pool_ids = df_all["id"].to_list()

def has_two_docs(qid):
    """True when the first df_all row with this id has a list of at least two passages in 'docs'."""
    first_match = df_all[df_all["id"] == qid].iloc[0]
    if not isinstance(first_match.get("docs", []), list):
        return False
    return len(first_match["docs"]) >= 2

# Reuse the saved few-shot IDs when the file exists; otherwise draw them once
# (preferring IDs with at least two passages) and save them.
if not os.path.exists(EXAMPLE_IDS_PATH):
    candidates = list(filter(has_two_docs, pool_ids))
    base = pool_ids if len(candidates) < NUM_EXAMPLES else candidates
    example_ids = random.sample(base, k=min(NUM_EXAMPLES, len(base)))
    with open(EXAMPLE_IDS_PATH, "w") as f:
        json.dump(example_ids, f)
else:
    with open(EXAMPLE_IDS_PATH, "r") as f:
        example_ids = json.load(f)
    # Drop saved IDs that are no longer present in df_all.
    example_ids = [i for i in example_ids if i in pool_ids]

print("Few-shot example IDs:", example_ids)

# The evaluation row must not be one of the few-shot examples.
eligible_rows = df_all.loc[~df_all["id"].isin(example_ids)]
row = eligible_rows.sample(1, random_state=2025).iloc[0]
print("Selected eval row ID:", row["id"])

# Run every strategy on the selected row.
df_one = run_all_strategies_with_debug(row, example_ids, n_docs=N_DOCS)

# Summary table of the metric columns.
summary_columns = ["strategy","precision_overlap","recall_overlap","precision_nli","recall_nli"]
print("\n" + "="*20 + " SUMMARY " + "="*20)
display(df_one[summary_columns])

# Persist question, prompts, answers and metrics.
df_one.to_csv("asqa_single_row_all_strategies_with_debug.csv", index=False)
print("Saved: asqa_single_row_all_strategies_with_debug.csv")


In [None]:
# Debug check of the sentence-position → citation-number mapping.
# NOTE(review): `out` is not assigned in this cell; it must be left over from
# an earlier cell, so this fails on a fresh kernel.
print(align_citations(out))     # should now show {0: [1], 1: [1], 2: [1]}


In [None]:
# Probe: generate for the first sampled row with the first strategy and show
# which passages the first sentence cites.
probe = sample_df.iloc[0]
out   = generate_with_enforcement(strategies[0][1], probe)
probe_cmap = align_citations(out)
print(probe_cmap)                                # should map sentences → [k]
# .get() avoids a KeyError when the output has no sentences; the range check
# skips markers such as [0] (which would silently index the LAST passage) and
# numbers beyond the number of passages.
for k in probe_cmap.get(0, []):
    if 1 <= k <= len(probe['docs']):
        print(probe['docs'][k-1]['title'])       # see what was cited


In [None]:
# Probe: generate for the first row of sample_df with the first entry of
# `strategies` and print the raw text.
probe_row = sample_df.iloc[0]
out = generate_with_enforcement(strategies[0][1], probe_row)
# repr() makes newlines and other escape sequences visible in the printed text.
print("RAW:\n", repr(out))          # shows escape codes
