In [26]:
import ujson as json
import re
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import random

In [27]:
with open("../data/train-claims.json") as f:
    train = json.load(f)
with open("../data/dev-claims.json") as f:
    dev = json.load(f)
with open("../data/evidence.json") as f:
    evidence = json.load(f)
with open("../data/test-evidence-faiss.json") as f:
    test_faiss = json.load(f)

In [28]:
def build_fewshot_prompt():

    FEWSHOT = """You are a fact verification assistant.
Your task is to choose the correct classification label for a given claim based on the supporting evidence statements.
Possible labels:
- SUPPORTS: All evidence supports the claim as factual.
- REFUTES: All evidence contradicts the claim.
- NOT_ENOUGH_INFO: The evidence is insufficient, irrelevant, or off-topic.
- DISPUTED: Evidence is conflicting or controversial.

Output only the label."""

    label_buckets = {"SUPPORTS": [], "REFUTES": [], "NOT_ENOUGH_INFO": [], "DISPUTED": []}
    for cid, obj in train.items():
        label = obj["claim_label"]
        if label in label_buckets and len(label_buckets[label]) < 3:
            sents = [evidence[eid] for eid in obj["evidences"] if eid in evidence]
            if len(sents) >= 1:
                label_buckets[label].append((obj["claim_text"], sents[:2]))

    # Flatten all examples into one string
    for label, examples in label_buckets.items():
        for claim, sents in examples:
            FEWSHOT += f"\n\n### EXAMPLE\nClaim: {claim}\nStatements:"
            for s in sents:
                FEWSHOT += f"\n- \"{s}\""
            FEWSHOT += f"\nClass: {label}"

    return FEWSHOT

In [29]:
def build_fewshot_prompt_cot():
    FEWSHOT = """You are a fact verification assistant.
Your task is to choose the correct classification label for a given claim based on the supporting evidence statements.
Possible labels:
- SUPPORTS: All evidence supports the claim as factual.
- REFUTES: All evidence contradicts the claim.
- NOT_ENOUGH_INFO: The evidence is insufficient, irrelevant, or off-topic.
- DISPUTED: Evidence is conflicting or controversial.

Before making a decision, explain your reasoning based on the statements.
Output the reasoning followed by the label."""

    label_buckets = {"SUPPORTS": [], "REFUTES": [], "NOT_ENOUGH_INFO": [], "DISPUTED": []}
    for cid, obj in train.items():
        label = obj["claim_label"]
        if label in label_buckets and len(label_buckets[label]) < 2:
            sents = [evidence[eid] for eid in obj["evidences"] if eid in evidence]
            if len(sents) >= 1:
                label_buckets[label].append((obj["claim_text"], sents[:2]))

    for label, examples in label_buckets.items():
        for claim, sents in examples:
            FEWSHOT += f"\n\n### EXAMPLE\nClaim: {claim}\nStatements:"
            for s in sents:
                FEWSHOT += f"\n- \"{s}\""
            FEWSHOT += f"\nReasoning: [explain how the evidences relates to the claim.]"  # LLM completes this
            FEWSHOT += f"\nClass: {label}"

    return FEWSHOT


In [30]:
FEWSHOT = build_fewshot_prompt() # chain of thought or normal

In [31]:
# Using Mistral-7B
MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
llm = pipeline(
    "text-generation",
    model=AutoModelForCausalLM.from_pretrained(
        MODEL,
        load_in_4bit=True,
        device_map="cuda",
        trust_remote_code=True
    ),
    tokenizer=AutoTokenizer.from_pretrained(MODEL),
    max_new_tokens=64,
    temperature=0.0 # labelling task, avoid random
)



The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 3/3 [00:47<00:00, 15.86s/it]
Device set to use cuda


In [32]:
PROMPT = FEWSHOT + """
### TASK
Claim: {claim}
Statements:
{evidence_lines}
Class:
"""

def classify_claim(claim: str, evidence_sents: list[str]) -> str:
    ev_lines = "\n".join(f"- \"{s}\"" for s in evidence_sents[:5]) # 5 evidences
    prompt   = PROMPT.format(claim=claim, evidence_lines=ev_lines)
    out      = llm(prompt)[0]["generated_text"].split("Class:")[-1].strip()
    match    = re.search(r"(SUPPORTS|REFUTES|NOT_ENOUGH_INFO|DISPUTED)", out.upper())
    return match.group(1).upper() if match else "NOT_ENOUGH_INFO"

In [33]:
out = {}
for cid, obj in tqdm(dev.items(), desc="LLM classify"):
    claim = obj["claim_text"]
    ev_sents = [evidence[eid] for eid in obj["evidences"]]
    label = classify_claim(claim, ev_sents)
    out[cid] = {"claim_label": label, "evidences": obj["evidences"]}




LLM classify:   0%|          | 0/154 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
LLM classify:   1%|▏         | 2/154 [00:37<48:44, 19.24s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
LLM classify:   2%|▏         | 3/154 [00:40<29:54, 11.88s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
LLM classify:   3%|▎         | 4/154 [00:45<23:08,  9.26s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
LLM classify:   3%|▎         | 5/154 [00:49<18:28,  7.44s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
LLM classify:   4%|▍         | 6/154 [00:54<16:23,  6.64s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
LLM classify:   5%|▍         | 7/154 [00:59<14:56,  6.10s/it]Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
LLM classify:   5%|▌         | 8/154 [01:03<12:41,  5.21s/it]Setting `pad_token_id` to `eo

KeyboardInterrupt: 

In [None]:
from collections import Counter

label_counts = Counter([entry["claim_label"] for entry in out.values()])

print("🔢 Label counts:")
for label, count in label_counts.items():
    print(f"{label:>20}: {count}")

🔢 Label counts:
            SUPPORTS: 72
     NOT_ENOUGH_INFO: 58
            DISPUTED: 17
             REFUTES: 7


In [None]:
with open("../data/dev-predicted-mistral.json", "w") as f:
    json.dump(out, f, indent=2)
print("✓ ")

✓ 


In [None]:
out = {}
for cid, obj in tqdm(test_faiss.items(), desc="LLM classify"):
    claim = obj["claim_text"]
    ev_sents = [evidence[eid] for eid in obj["evidences"]]
    label = classify_claim(claim, ev_sents)
    out[cid] = {"claim_label": label, "evidences": obj["evidences"]}


with open("../data/test-predicted-mistral.json", "w") as f:
    json.dump(out, f, indent=2)
print("✓ ")

In [None]:
from collections import Counter

label_counts = Counter([entry["claim_label"] for entry in out.values()])

print("🔢 Label counts:")
for label, count in label_counts.items():
    print(f"{label:>20}: {count}")