In [None]:
!pip -q install transformers sentence-transformers trafilatura tqdm regex

import os, re, json, time, math, requests, concurrent.futures
from dataclasses import dataclass
from typing import List, Dict, Any, Optional, Tuple
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sentence_transformers import SentenceTransformer, util as st_util
import trafilatura

# Fast, deterministic-ish
import random, numpy as np
def set_seed(s=42):
    """Seed Python's `random`, NumPy and PyTorch with `s` (and every CUDA device when one is available)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(s)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(s)

set_seed(42)

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Mixed-precision dtype: bf16 on GPUs that support it, fp16 otherwise.
AMP_DTYPE = torch.float16
if DEVICE.type == "cuda" and torch.cuda.is_bf16_supported():
    AMP_DTYPE = torch.bfloat16

print("Device:", DEVICE, "AMP:", AMP_DTYPE)


In [None]:
# <<< SET THESE IF YOU HAVE THEM >>>
# API keys are read from the environment; an empty string disables that search provider.
SERPAPI_KEY = os.getenv("SERPAPI_KEY", "")      # https://serpapi.com/
TAVILY_KEY  = os.getenv("TAVILY_API_KEY", "")   # https://tavily.com/
VOYAGE_KEY  = os.getenv("VOYAGE_API_KEY", "")   # if you use Voyage (NOTE(review): not referenced by the code shown here)

# Source preferences (keeps quality high).
# Domains of the news outlets whose search results should be kept, e.g. "reuters.com".
# Leave the list empty to accept results from any site.
WHITELIST = [
    # "reuters.com",
    # "apnews.com",
]

# Retrieval & ranking knobs (good defaults)
MAX_RESULTS = 20          # cap on merged search results (also the per-provider request size)
BI_MIN_COS  = 0.35        # min bi-encoder cosine similarity for a candidate to reach the cross-encoder
TOP_K_AFTER_RERANK = 6    # number of reranked articles passed on to NLI verification

# NLI thresholds (balanced)
ENTAIL_THR = 0.65         # entailment probability needed to label a claim SUPPORTED
CONTRA_THR = 0.55         # contradiction probability needed for REFUTED (checked only when not SUPPORTED)

# Timeouts
HTTP_TIMEOUT = 15         # seconds, used for the HTTP requests in this notebook


In [None]:
def search_serpapi(q: str, n=MAX_RESULTS) -> List[Dict[str, Any]]:
    """Query Google News through SerpAPI and return normalized hit dicts.

    Each hit has the keys source, title, snippet, date and url; hits without a link are
    skipped. Returns [] when SERPAPI_KEY is empty or when the request/parsing fails
    (the error is printed).
    """
    if not SERPAPI_KEY:
        return []
    query_params = {
        "engine": "google",
        "q": q,
        "num": min(n, 20),   # request at most 20 results
        "tbm": "nws",        # Google News vertical
        "api_key": SERPAPI_KEY,
    }
    try:
        resp = requests.get("https://serpapi.com/search.json", params=query_params, timeout=HTTP_TIMEOUT)
        resp.raise_for_status()
        payload = resp.json()
        return [
            {
                "source": hit.get("source",""),
                "title": hit.get("title",""),
                "snippet": hit.get("snippet",""),
                "date": hit.get("date",""),
                "url": hit.get("link") or "",
            }
            for hit in payload.get("news_results", [])
            if hit.get("link") or ""
        ]
    except Exception as e:
        print("SerpAPI error:", e)
        return []

def search_tavily(q: str, n=MAX_RESULTS) -> List[Dict[str, Any]]:
    """Query the Tavily search API and return normalized hit dicts.

    Each hit has the keys source, title, snippet (first 300 chars of Tavily's "content"),
    date and url. Like search_serpapi, hits without a URL are skipped. Null fields are
    normalized to "" so one malformed result cannot abort the whole batch.
    Returns [] when TAVILY_KEY is empty or when the request/parsing fails (the error is printed).
    """
    if not TAVILY_KEY:
        return []
    try:
        r = requests.post(
            "https://api.tavily.com/search",
            json={"api_key": TAVILY_KEY, "query": q, "max_results": n, "search_depth": "basic"},
            timeout=HTTP_TIMEOUT,
        )
        r.raise_for_status()
        data = r.json()
        items = []
        for it in data.get("results") or []:
            url = it.get("url") or ""
            if not url:
                # Nothing to download or whitelist-check without a URL.
                continue
            items.append({
                "source": it.get("source") or "",
                "title": it.get("title") or "",
                # `or ""` guards against an explicit null, which would make the slice raise.
                "snippet": (it.get("content") or "")[:300],
                "date": it.get("published_date") or "",
                "url": url,
            })
        return items
    except Exception as e:
        # Best effort: a failing provider just contributes no results.
        print("Tavily error:", e)
        return []

def combined_search(q: str) -> List[Dict[str, Any]]:
    """Search both providers for `q`, filter by WHITELIST, de-duplicate, and cap at MAX_RESULTS.

    SerpAPI hits come first, then Tavily hits, each in provider order. When WHITELIST is
    non-empty, a hit is kept only if its URL host equals a whitelisted domain or is a
    subdomain of one; hits whose host cannot be determined are dropped. Duplicates are
    detected by (host, lower-cased title), falling back to the URL when the title is empty.
    """
    from urllib.parse import urlparse

    def host(u: str) -> str:
        # Lower-cased hostname (no scheme, port or credentials); "" when absent/unparseable.
        try:
            return (urlparse(u).hostname or "").lower()
        except ValueError:
            return ""

    allowed = [w.strip().lower().rstrip("/") for w in WHITELIST if w and w.strip()]

    def is_allowed(h: str) -> bool:
        # Exact domain or subdomain match; a plain substring test would also accept
        # look-alikes such as "bbc.com.example.net" or "notbbc.com".
        return bool(h) and any(h == w or h.endswith("." + w) for w in allowed)

    items = search_serpapi(q) + search_tavily(q)
    seen = set()
    out = []
    for it in items:
        url = it.get("url") or ""
        h = host(url)
        if allowed and not is_allowed(h):
            continue
        title = (it.get("title") or "").strip().lower()
        key = (h, title or url)
        if key in seen:
            continue
        seen.add(key)
        out.append(it)
    return out[:MAX_RESULTS]

def fetch_html(url: str) -> str:
    """Download `url` and return the response body as text, or "" on any failure (best effort)."""
    try:
        r = requests.get(url, timeout=HTTP_TIMEOUT, headers={"User-Agent":"Mozilla/5.0"})
        r.raise_for_status()
        # For text/* responses whose Content-Type carries no charset, requests assumes
        # ISO-8859-1, which garbles UTF-8 pages; let it sniff the real encoding instead.
        if "charset" not in r.headers.get("Content-Type", "").lower():
            r.encoding = r.apparent_encoding or r.encoding
        return r.text
    except Exception:
        # Deliberately best-effort: an unreachable or broken page just contributes no text.
        return ""

def extract_text(url: str) -> str:
    """Fetch `url` and return trafilatura's extracted text (comments/tables excluded), whitespace-collapsed.

    Returns "" when the page could not be downloaded or nothing was extracted.
    """
    page = fetch_html(url)
    if not page:
        return ""
    article = trafilatura.extract(page, include_comments=False, include_tables=False)
    # collapse every run of whitespace (newlines, tabs, spaces) into one space
    return re.sub(r"\s+", " ", article or "").strip()

def enrich_with_body(items: List[Dict[str,Any]]) -> List[Dict[str,Any]]:
    """Return copies of `items` with the extracted article text under "body", dropping those with no text.

    Downloads run on 8 worker threads; output order follows input order and the input
    dicts themselves are not modified.
    """
    def with_body(item):
        return {**item, "body": extract_text(item["url"])}

    with concurrent.futures.ThreadPoolExecutor(max_workers=8) as pool:
        enriched = list(tqdm(pool.map(with_body, items), total=len(items), desc="Downloading articles"))
    return [doc for doc in enriched if doc.get("body")]


In [None]:
# Bi-encoder: cheap cosine-similarity prefilter over candidate articles
bi_device = "cuda" if DEVICE.type == "cuda" else "cpu"
bi = SentenceTransformer("sentence-transformers/multi-qa-MiniLM-L6-cos-v1", device=bi_device)

# Cross-encoder: scores each (claim, candidate) pair jointly for the rerank step
ce_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
ce_tok = AutoTokenizer.from_pretrained(ce_name)
ce = AutoModelForSequenceClassification.from_pretrained(ce_name)
ce = ce.to(DEVICE).eval()


def select_top_candidates(claim: str, docs: List[Dict[str,Any]], k=TOP_K_AFTER_RERANK, fallback_n: int = 6) -> List[Dict[str,Any]]:
    """Pick the `k` docs most relevant to `claim`: bi-encoder prefilter, then cross-encoder rerank.

    Each doc is represented by its title (or "headline") plus the first 3 sentences of "body".
    Docs whose bi-encoder cosine similarity to the claim is below BI_MIN_COS are dropped; if
    that would drop every doc, the `max(k, fallback_n)` most similar docs are kept instead.
    Survivors are sorted by cross-encoder score (descending) and the top `k` are returned as
    shallow copies with an added float "rerank_score"; the caller's dicts are not modified.
    """
    if not docs:
        return []

    def lead(body: str, sent_n: int = 3) -> str:
        # naive sentence split: ., ! or ? followed by whitespace
        sents = re.split(r'(?<=[.!?])\s+', body.strip())
        return " ".join(sents[:sent_n])

    cand_texts = []
    for d in docs:
        title = (d.get("title") or d.get("headline") or "").strip()
        cand_texts.append((title + ". " + lead(d.get("body") or "")).strip())

    # 1) Bi-encoder cosine prefilter (embeddings are L2-normalized)
    q_emb = bi.encode([claim], convert_to_tensor=True, normalize_embeddings=True)
    doc_emb = bi.encode(cand_texts, convert_to_tensor=True, normalize_embeddings=True)
    sims = st_util.cos_sim(q_emb, doc_emb).cpu().numpy().ravel()
    kept_idx = [i for i in range(len(docs)) if sims[i] >= BI_MIN_COS]
    if not kept_idx:
        # Nothing clears the threshold: keep the most similar docs anyway rather than return nothing.
        kept_idx = sorted(range(len(docs)), key=lambda i: -sims[i])[:max(k, fallback_n)]

    # 2) Cross-encoder rerank of (claim, candidate) pairs, in batches of 16
    pairs = [(claim, cand_texts[i]) for i in kept_idx]
    scores: List[float] = []
    with torch.no_grad():
        for start in range(0, len(pairs), 16):
            batch = pairs[start:start + 16]
            enc = ce_tok(
                [a for a, _ in batch], [b for _, b in batch],
                truncation=True, max_length=384, padding=True, return_tensors="pt",
            ).to(DEVICE)
            # assumes the cross-encoder emits one relevance logit per pair: (B, 1) -> (B,)
            scores.extend(ce(**enc).logits.squeeze(-1).float().cpu().tolist())

    order = sorted(range(len(kept_idx)), key=lambda j: -scores[j])[:k]
    top = []
    for j in order:
        doc = dict(docs[kept_idx[j]])   # copy so the caller's dicts stay untouched
        doc["rerank_score"] = float(scores[j])
        top.append(doc)
    return top


In [None]:
NLI_NAME = "MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli"

nli_tok  = AutoTokenizer.from_pretrained(NLI_NAME)
nli      = AutoModelForSequenceClassification.from_pretrained(NLI_NAME).to(DEVICE).eval()
# The logit order is checkpoint-specific, so inspect it instead of assuming [C, N, E].
# NOTE(review): MoritzLaurer's NLI checkpoints are believed to list entailment first — confirm in this output.
print("NLI labels:", nli.config.id2label)
# Map by label *names*, not indices: lower-cased label name -> logit index
lbl_map = {v.lower(): int(k) for k, v in nli.config.id2label.items()}

def _get_idx(name_opts, default_idx: Optional[int] = None) -> int:
    """Resolve the logit index of one NLI class from a list of candidate label names.

    Each name is looked up (case-insensitively) in the module-level `lbl_map`
    (label name -> index). If none matches, falls back to `default_idx` when given,
    otherwise to N from the first "label_N" alias in `name_opts` — e.g.
    ["contradiction", "label_0", ...] falls back to 0, giving the conventional [C, N, E] order.

    Raises:
        KeyError: no name matched and no fallback index could be derived.
    """
    for name in name_opts:
        key = name.lower()
        if key in lbl_map:
            return lbl_map[key]
    if default_idx is not None:
        return default_idx
    # Positional fallback. Previously a dict was returned here, which cannot index probabilities.
    for name in name_opts:
        m = re.fullmatch(r"label_(\d+)", name.lower())
        if m:
            return int(m.group(1))
    raise KeyError(f"none of {list(name_opts)} found in NLI labels {sorted(lbl_map)}")

# Logit position of each NLI class, resolved by label name.
C_IDX, N_IDX, E_IDX = (
    _get_idx(["contradiction", "label_0", "contradict", "c"]),
    _get_idx(["neutral", "label_1", "n"]),
    _get_idx(["entailment", "label_2", "entails", "e"]),
)

print("Resolved label indices ->", {"C":C_IDX, "N":N_IDX, "E":E_IDX})

def yield_chunks(text: str, max_prem_toks=420):
    """Yield `text` as consecutive windows of at most `max_prem_toks` NLI-tokenizer tokens, decoded back to strings."""
    token_ids = nli_tok.encode(text, add_special_tokens=False)
    yield from (
        nli_tok.decode(token_ids[pos:pos + max_prem_toks], skip_special_tokens=True)
        for pos in range(0, len(token_ids), max_prem_toks)
    )

@torch.no_grad()
def nli_best_chunk(premise_text: str, claim: str) -> Tuple[float,float,float,str]:
    """Score `claim` against each token window of `premise_text` and keep the most entailing one.

    Returns (C, N, E, chunk): the contradiction/neutral/entailment probabilities of the window
    with the highest entailment probability, plus that window's text. Returns
    (-1.0, -1.0, -1.0, "") when the premise produces no windows.
    """
    best = (-1.0, -1.0, -1.0, "")
    for ch in yield_chunks(premise_text):
        enc = nli_tok(ch, claim, truncation=True, max_length=512, padding=False, return_tensors="pt").to(DEVICE)
        # Autocast on GPU only; torch.no_grad() serves as a no-op context on CPU.
        amp_ctx = torch.autocast(device_type=DEVICE.type, dtype=AMP_DTYPE) if DEVICE.type == "cuda" else torch.no_grad()
        with amp_ctx:
            logits = nli(**enc).logits
        # Softmax in float32: bf16 results cannot be converted via .numpy() and lose precision.
        probs = torch.softmax(logits.float(), dim=-1)[0]
        # Index by the resolved label positions; the checkpoint's order need not be [C, N, E].
        C, N, E = float(probs[C_IDX]), float(probs[N_IDX]), float(probs[E_IDX])
        if E > best[2]:
            best = (C, N, E, ch)
    return best  # (C, N, E, chunk)

def verify_claim(claim: str, docs: List[Dict[str,Any]]) -> Dict[str,Any]:
    """Classify `claim` against `docs` with NLI.

    For each doc, the premise is its title (or "headline") followed by up to the first 10
    sentences of "body", scored by nli_best_chunk. The verdict is:
      * SUPPORTED if the best entailment probability over all docs is >= ENTAIL_THR,
      * otherwise REFUTED if the best contradiction probability is >= CONTRA_THR,
      * otherwise NO_EVIDENCE, with the best scores included for inspection.
    SUPPORTED/REFUTED results carry "confidence" and the winning per-doc "evidence"
    record {"doc", "C", "N", "E", "chunk"}. Empty `docs` gives NO_EVIDENCE / "no documents".
    """
    results = []
    for d in docs:
        title = (d.get("title") or d.get("headline") or "").strip()
        sents = re.split(r'(?<=[.!?])\s+', (d.get("body") or "").strip())
        # Up to 10 leading sentences, regardless of article length.
        lead = " ".join(sents[:10])
        premise = f"{title}. {lead}".strip() if title else lead
        C,N,E,ch = nli_best_chunk(premise, claim)
        results.append({"doc": d, "C": C, "N": N, "E": E, "chunk": ch})

    if not results:
        return {"label":"NO_EVIDENCE", "reason":"no documents"}

    best_ent = max(results, key=lambda r: r["E"])
    best_con = max(results, key=lambda r: r["C"])

    # Entailment takes precedence when both thresholds are met.
    if best_ent["E"] >= ENTAIL_THR:
        return {"label":"SUPPORTED", "confidence": float(best_ent["E"]), "evidence": best_ent}
    if best_con["C"] >= CONTRA_THR:
        return {"label":"REFUTED", "confidence": float(best_con["C"]), "evidence": best_con}

    return {"label":"NO_EVIDENCE", "reason":"all neutral or low-confidence",
            "best_entail": float(best_ent["E"]), "best_contra": float(best_con["C"])}


In [None]:
# ===== Cell 5 — NLI (robust label mapping + safe chunking) =====
from typing import Tuple
import torch, re
from transformers import AutoTokenizer, AutoModelForSequenceClassification

NLI_NAME = "MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli"

# Reuse the copy of this checkpoint that an earlier cell may already have loaded:
# loading this large model a second time costs minutes and (on GPU) extra memory.
_prev_nli = globals().get("nli")
if (
    _prev_nli is not None
    and "nli_tok" in globals()
    and getattr(getattr(_prev_nli, "config", None), "name_or_path", None) == NLI_NAME
):
    print("Reusing already-loaded NLI model:", NLI_NAME)
else:
    nli_tok  = AutoTokenizer.from_pretrained(NLI_NAME)
    nli      = AutoModelForSequenceClassification.from_pretrained(NLI_NAME).to(DEVICE).eval()
del _prev_nli
print("NLI labels:", nli.config.id2label)

# 1) Map labels by NAME, not by index (label order differs between checkpoints)
lbl_map = {v.lower(): int(k) for k, v in nli.config.id2label.items()}

def idx_for(name_opts, default_idx):
    """Return `lbl_map`'s index for the first name in `name_opts` present (case-insensitive), else `default_idx`."""
    return next(
        (lbl_map[key] for key in (opt.lower() for opt in name_opts) if key in lbl_map),
        default_idx,  # fallback to common order [C,N,E] -> 0,1,2
    )

# Position of each NLI class in the logits, looked up by label name
# (falling back to the common [C, N, E] = [0, 1, 2] order).
C_IDX, N_IDX, E_IDX = (
    idx_for(["contradiction", "label_0", "contradict", "c"], 0),
    idx_for(["neutral", "label_1", "n"], 1),
    idx_for(["entailment", "label_2", "entails", "e"], 2),
)

print("Resolved label indices ->", {"C": C_IDX, "N": N_IDX, "E": E_IDX})

# 2) Dynamic chunking that leaves room for the claim; truncate ONLY the premise
def yield_chunks_dynamic(premise_text: str, claim: str, max_total=512, safety=12):
    """Yield decoded token windows of `premise_text` sized to leave room for `claim`.

    The window size is `max_total` minus the claim's token count, `safety`, and 3 slots
    for special tokens, but never less than 64 tokens (so a very long claim can still
    push a pair past `max_total`).
    """
    claim_len = len(nli_tok.encode(claim, add_special_tokens=False))
    window = max(64, max_total - claim_len - safety - 3)  # 3 ≈ special tokens around a pair
    premise_ids = nli_tok.encode(premise_text, add_special_tokens=False)
    yield from (
        nli_tok.decode(premise_ids[start:start + window], skip_special_tokens=True)
        for start in range(0, len(premise_ids), window)
    )

@torch.no_grad()
def nli_best_chunk(premise_text: str, claim: str) -> Tuple[float,float,float,str]:
    """Return (C, N, E, chunk) for the premise window that most strongly entails `claim`.

    The premise is split by yield_chunks_dynamic, and only the premise side of each
    (window, claim) pair is ever truncated. C/N/E are the contradiction, neutral and
    entailment probabilities of the window with the highest E, so a contradiction that
    appears only in a different window is not reflected. With no windows the result is
    (-1.0, -1.0, -1.0, "").
    """
    best_c = best_n = best_e = -1.0
    best_text = ""
    on_gpu = DEVICE.type == "cuda"
    for window in yield_chunks_dynamic(premise_text, claim, max_total=512, safety=12):
        enc = nli_tok(
            window, claim,
            truncation="only_first",
            max_length=512,
            padding=False,
            return_tensors="pt",
        ).to(DEVICE)
        # Autocast on GPU only; torch.no_grad() serves as a no-op context on CPU.
        amp_ctx = torch.autocast(device_type=DEVICE.type, dtype=AMP_DTYPE) if on_gpu else torch.no_grad()
        with amp_ctx:
            logits = nli(**enc).logits
        # Upcast to float32 before softmax/float(): bf16 tensors can't be converted via numpy.
        probs = logits.float().softmax(dim=-1)[0]
        c, n, e = float(probs[C_IDX]), float(probs[N_IDX]), float(probs[E_IDX])
        if e > best_e:
            best_c, best_n, best_e, best_text = c, n, e, window
    return best_c, best_n, best_e, best_text

def verify_claim(claim: str, docs: List[Dict[str,Any]]) -> Dict[str,Any]:
    """Classify `claim` against `docs` with NLI.

    For each doc, the premise is its title (or "headline") followed by up to the first 10
    sentences of "body", scored by nli_best_chunk. The verdict is:
      * SUPPORTED if the best entailment probability over all docs is >= ENTAIL_THR,
      * otherwise REFUTED if the best contradiction probability is >= CONTRA_THR,
      * otherwise NO_EVIDENCE, with the best scores included for inspection.
    SUPPORTED/REFUTED results carry "confidence" and the winning per-doc "evidence"
    record {"doc", "C", "N", "E", "chunk"}. Empty `docs` gives NO_EVIDENCE / "no documents".
    """
    results = []
    for d in docs:
        title = (d.get("title") or d.get("headline") or "").strip()
        sents = re.split(r'(?<=[.!?])\s+', (d.get("body") or "").strip())
        # Up to 10 leading sentences, regardless of article length.
        lead = " ".join(sents[:10])
        premise = f"{title}. {lead}".strip() if title else lead
        C,N,E,ch = nli_best_chunk(premise, claim)
        results.append({"doc": d, "C": C, "N": N, "E": E, "chunk": ch})

    if not results:
        return {"label":"NO_EVIDENCE", "reason":"no documents"}

    best_ent = max(results, key=lambda r: r["E"])
    best_con = max(results, key=lambda r: r["C"])

    # Entailment takes precedence when both thresholds are met.
    if best_ent["E"] >= ENTAIL_THR:
        return {"label":"SUPPORTED", "confidence": float(best_ent["E"]), "evidence": best_ent}
    if best_con["C"] >= CONTRA_THR:
        return {"label":"REFUTED", "confidence": float(best_con["C"]), "evidence": best_con}

    return {
        "label":"NO_EVIDENCE",
        "reason":"all neutral or low-confidence",
        "best_entail": float(best_ent["E"]),
        "best_contra": float(best_con["C"])
    }


In [None]:
def fact_check(claim: str) -> Dict[str,Any]:
    """Fact-check one claim end to end and print a short report.

    Pipeline: combined_search -> enrich_with_body -> select_top_candidates -> verify_claim.
    Returns verify_claim's result dict, or {"label": "NO_EVIDENCE", "reason": ...} when the
    claim is blank or an earlier stage produces nothing.
    """
    claim = (claim or "").strip()
    if not claim:
        # Nothing to search for; skip all network and model calls.
        return {"label":"NO_EVIDENCE", "reason":"empty claim"}

    print(f"\n=== CLAIM ===\n{claim}\n")

    # 1. Retrieve headlines/snippets
    hits = combined_search(claim)   # from Cell 3
    if not hits:
        return {"label":"NO_EVIDENCE", "reason":"search returned 0 items"}

    # 2. Fetch full article bodies
    docs = enrich_with_body(hits)   # from Cell 3
    if not docs:
        return {"label":"NO_EVIDENCE", "reason":"no article text extracted"}

    # 3. Rerank/filter candidates
    top_docs = select_top_candidates(claim, docs)   # from Cell 4

    # 4. Verify claim with NLI
    out = verify_claim(claim, top_docs)   # from Cell 5

    # 5. Pretty print evidence
    label = out["label"]
    print(f"\nRESULT: {label}")
    if label in ("SUPPORTED","REFUTED"):
        ev = out["evidence"]; d = ev["doc"]
        print(f"Confidence: {out['confidence']:.3f}")
        print(f"Source: {d.get('source','?')}\nTitle: {d.get('title','')}\nURL: {d.get('url','')}")
        print("\n--- Evidence excerpt ---")
        print((ev["chunk"] or "").strip()[:800])
    else:
        print("Reason:", out.get("reason"))
        print(f"Best entail={out.get('best_entail',0):.3f} | Best contra={out.get('best_contra',0):.3f}")

    return out


In [None]:
CLAIM = ""  # <<< put the claim you want to fact-check here
result = fact_check(CLAIM)
print(result)
