In [40]:
from sentence_transformers import SentenceTransformer
import faiss, numpy as np, ujson as json
import os

In [41]:
with open('../data/train-claims.json') as f:
    train = json.load(f)
with open('../data/dev-claims.json') as f:
    dev = json.load(f)
with open('../data/evidence.json') as f:
    evidence = json.load(f)


In [42]:
MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
EMBED_CACHE = "../data/evidence_vecs.npy"
ID_CACHE = "../data/evidence_ids.json"
FAISS_CACHE = "../data/evidence.faiss"

In [43]:
# encode the evidences if cached can't found, then cache it
if os.path.exists(FAISS_CACHE) and os.path.exists(ID_CACHE):
    index = faiss.read_index(FAISS_CACHE)
    with open(ID_CACHE) as f:
        ids = json.load(f)
else:
    ids = list(evidence.keys())
    docs = list(evidence.values())
    vecs = embedder.encode(docs,
                           batch_size=32,
                           convert_to_numpy=True,
                           normalize_embeddings=True)
    index = faiss.IndexFlatIP(vecs.shape[1])
    index.add(vecs)
    faiss.write_index(index, FAISS_CACHE)
    with open(ID_CACHE, "w") as f:
        json.dump(ids, f)

In [44]:
# retrieve top 100 evidence via FAISS embedding
def topk_ids(claim, k=100):
    v = embedder.encode(claim,
                        convert_to_numpy=True,
                        normalize_embeddings=True)
    _, I = index.search(v[None], k)
    return [ids[i] for i in I[0]]

In [45]:
from sentence_transformers import CrossEncoder

# Using a cross-encoder to rerank the 100 evidence retrieved
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', device='cuda')

def reranker_score(claim, evidence_text):
    return reranker.predict([(claim, evidence_text)])[0]


In [46]:
print("Retrieving and reranking top evidence for train claims...")
out = {}
for cid, itm in tqdm(train.items()):
    claim = itm["claim_text"]
    candidates = topk_ids(claim, k=100)

    pairs = [(claim, evidence[eid]) for eid in candidates]
    scores = reranker.predict(pairs)

    ranked = [eid for _, eid in sorted(zip(scores, candidates), reverse=True)]
    top6 = ranked[:6] # as explored, 6 evidence give best score

    out[cid] = {
        "claim_text": claim,
        "claim_label": "",
        "evidences": top6
    }


Retrieving and reranking top evidence for train claims...


100%|██████████| 1228/1228 [02:29<00:00,  8.21it/s]


In [47]:
with open("../data/train-evidence-faiss.json", "w") as f:
    json.dump(out, f, indent=2)

print("✓")

✓


In [50]:
print("Retrieving and reranking top evidence for dev claims...")
out = {}
for cid, itm in tqdm(dev.items()):
    claim = itm["claim_text"]
    candidates = topk_ids(claim, k=100)

    pairs = [(claim, evidence[eid]) for eid in candidates]
    scores = reranker.predict(pairs)

    ranked = [eid for _, eid in sorted(zip(scores, candidates), reverse=True)]
    top6 = ranked[:6]

    out[cid] = {
        "claim_text": claim,
        "claim_label": "",
        "evidences": top6
    }
with open("../data/dev-evidence-faiss.json", "w") as f:
    json.dump(out, f, indent=2)

print("Saved ✓")

Retrieving and reranking top evidence for dev claims...


100%|██████████| 154/154 [00:19<00:00,  8.07it/s]

Saved ✓





In [51]:
print("Retrieving and reranking top evidence for test claims...")
out = {}
for cid, itm in tqdm(train.items()):
    claim = itm["claim_text"]
    candidates = topk_ids(claim, k=100)

    pairs = [(claim, evidence[eid]) for eid in candidates]
    scores = reranker.predict(pairs)

    ranked = [eid for _, eid in sorted(zip(scores, candidates), reverse=True)]
    top6 = ranked[:6]

    out[cid] = {
        "claim_text": claim,
        "claim_label": "",
        "evidences": top6
    }
with open("../data/test-evidence-faiss.json", "w") as f:
    json.dump(out, f, indent=2)

print("Saved ✓")

Retrieving and reranking top evidence for test claims...


100%|██████████| 1228/1228 [02:31<00:00,  8.13it/s]

Saved ✓



