In [1]:
import argparse
import copy
import json
import logging
import math
import os
import pickle
import re
import numpy
import nltk
import torch

from scipy import spatial
from sentence_transformers import SentenceTransformer

from retriever.dense_retriever import DenseRetriever
from retriever.sparse_retriever_fast import SparseRetrieverFast

In [2]:
nltk.download('punkt')
logging.getLogger().setLevel(logging.INFO)

[nltk_data] Downloading package punkt to /home/yic055/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## MUSER Code Structure
#### ref: https://github.com/Complex-data/MUSER

In [17]:
class NewsRetriever:
    def __init__(self, docs_file=None, index_path='index', models_path='models/weights', encoder_batch_size=32):
        # Initialization code as in the script provided
        self.index_path = index_path
        self.encoder_batch_size = encoder_batch_size

        device = 'cpu'
        if torch.cuda.is_available():
            device = 'cuda'

        # initialize the sentence tokenizer
        self.sent_tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
        self.sent_tokenizer._params.abbrev_types.update(['e.g', 'i.e', 'subsp'])

        # initialize the passage embedding model
        self.text_embedding_model = SentenceTransformer('all-mpnet-base-v2',
                                                        device=device)

        if docs_file is None:
            if os.path.exists('{}/vectors.pkl'.format(self.index_path)):
                self.dense_index = DenseRetriever(model=self.text_embedding_model, batch_size=32)
                self.dense_index.create_index_from_vectors('{}/vectors.pkl'.format(index_path))
                self.sparse_index = SparseRetrieverFast(path=self.index_path)
                self.documents = pickle.load(open('{}/documents.pkl'.format(index_path), 'rb'))

        else:
            self.index_documents(docs_file=docs_file)

    def index_documents(self, docs_file, sentences_per_snippet=5):
        # Code for indexing documents
        logging.info('Indexing snippets...')

        self.documents = {}
        all_snippets = []
        with open(docs_file, encoding='utf-8') as f:
            for i, line in enumerate(f):
                document = json.loads(line.rstrip('\n'))
                snippets = self.extract_snippets(document["text"], sentences_per_snippet)
                for snippet in snippets:
                    all_snippets.append(snippet)
                    self.documents[len(self.documents)] = {
                        'snippet': snippet
                    }
                if i % 1000 == 0:
                    logging.info('processed: {} - snippets: {}'.format(i, len(all_snippets)))

        # Check if the index_path directory exists, create it if not
        if not os.path.exists(self.index_path):
            os.makedirs(self.index_path)

        pickle.dump(self.documents, open('{}/documents.pkl'.format(self.index_path), 'wb'))

        logging.info('Building sparse index...')

        self.sparse_index = SparseRetrieverFast(path=self.index_path) 
        self.sparse_index.index_documents(all_snippets)

        logging.info('Building dense index...')

        self.dense_index = DenseRetriever(model=self.text_embedding_model,
                                          batch_size=self.encoder_batch_size)
        self.dense_index.create_index_from_documents(all_snippets)
        self.dense_index.save_index(vectors_path='{}/vectors.pkl'.format(self.index_path))

        logging.info('Done')

    def extract_snippets(self, text, sentences_per_snippet=5):
        # Code for extracting snippets from text
        sentences = self.sent_tokenizer.tokenize(text)
        snippets = []
        i = 0
        last_index = 0
        while i < len(sentences):
            snippet = ' '.join(sentences[i:i + sentences_per_snippet])
            if len(snippet.split(' ')) > 4:
                snippets.append(snippet)
            last_index = i + sentences_per_snippet
            i += int(math.ceil(sentences_per_snippet / 2))
        if last_index < len(sentences):
            snippet = ' '.join(sentences[last_index:])
            if len(snippet.split(' ')) > 4:
                snippets.append(snippet)
        return snippets

    def search(self, query, limit=100):
        # Code for performing a search
        logging.info('Running sparse retriever for: {}'.format(query))

        sparse_results = self.sparse_index.search([query], topk=limit)[0]
        sparse_results = [r[0] for r in sparse_results]

        logging.info('Running dense retriever for: {}'.format(query))

        dense_results = self.dense_index.search([query], limit=limit)[0]
        dense_results = [r[0] for r in dense_results]

        results = list(set(sparse_results + dense_results))

        # print(sparse_results)
        # print(len(self.documents))
        search_results = []
        if len(results) > 0:
#             for i in range(len(results)):
#                 doc_id = results[i]
#                 result = copy.copy(self.documents[doc_id])
#                 search_results.append(result)
            
            for i in range(len(results)):
                doc_id = results[i]
                try:
                    result = copy.copy(self.documents[doc_id])
                    search_results.append(result)
                except KeyError as e:
                    logging.warning(f"Document ID {doc_id} not found in `self.documents`.")
                    
        paragraphs = search_results.copy()
        logging.info('highlighting...')
        results_sentences = []
        sentences_texts = []
        sentences_vectors = {}
        for i, r in enumerate(search_results):
            sentences = self.sent_tokenizer.tokenize(r['snippet'])
            sentences = [s for s in sentences if len(s.split(' ')) > 4]
            sentences_texts.extend(sentences)
            results_sentences.append(sentences)

        vectors = self.text_embedding_model.encode(sentences=sentences_texts, batch_size=128)
        for i, v in enumerate(vectors):
            sentences_vectors[sentences_texts[i]] = v

        query_vector = self.text_embedding_model.encode(sentences=[query], batch_size=1)[0]
        for i, sentences in enumerate(results_sentences):
            best_sentences = set()
            evidence_sentences = []
            for sentence in sentences:
                sentence_vector = sentences_vectors[sentence]
                score = 1 - spatial.distance.cosine(query_vector, sentence_vector)
                if score > 0.9:
                    best_sentences.add(sentence)
                    evidence_sentences.append(sentence)
            if len(evidence_sentences) > 0:
                search_results[i]['evidence'] = ' '.join(evidence_sentences)
            search_results[i]['snippet'] = \
                ' '.join([s if s not in best_sentences else '<b>{}</b>'.format(s) for s in sentences])

        search_results = [s for s in search_results if 'evidence' in s]

        search_results = search_results[:limit]
        paragraphs = paragraphs[:limit]
        logging.info('done searching')
        return search_results,paragraphs

## Initiate Instance

In [18]:
docs_file = 'data/polusa2019.jsonl'  # Example path to documents file
index_path = 'index'
models_path = 'models/weights'
encoder_batch_size = 32
limit = 10

# Instantiate the NewsRetriever
q = NewsRetriever(docs_file=docs_file,
                  index_path=index_path,
                  models_path=models_path,
                  encoder_batch_size=encoder_batch_size)

INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-mpnet-base-v2
INFO:root:Indexing snippets...
INFO:root:processed: 0 - snippets: 1
INFO:root:processed: 1000 - snippets: 8375
INFO:root:processed: 2000 - snippets: 15853
INFO:root:processed: 3000 - snippets: 23321
INFO:root:processed: 4000 - snippets: 31179
INFO:root:processed: 5000 - snippets: 38908
INFO:root:processed: 6000 - snippets: 46622
INFO:root:processed: 7000 - snippets: 53578
INFO:root:processed: 8000 - snippets: 61330
INFO:root:processed: 9000 - snippets: 68736
INFO:root:processed: 10000 - snippets: 77169
INFO:root:processed: 11000 - snippets: 84716
INFO:root:processed: 12000 - snippets: 92289
INFO:root:processed: 13000 - snippets: 100265
INFO:root:processed: 14000 - snippets: 107924
INFO:root:processed: 15000 - snippets: 115464
INFO:root:processed: 16000 - snippets: 123017
INFO:root:processed: 17000 - snippets: 131812
INFO:root:processed: 18000 - snippets: 138868
INFO:root:processed: 190

Batches:   0%|          | 0/4827 [00:00<?, ?it/s]

INFO:root:Indexing 154446 vectors
INFO:root:Using 2896 centroids
INFO:root:Training index...
INFO:root:Adding vectors to index...
INFO:root:Built index
INFO:root:Done


## Summarized Evidence (Limited)

In [15]:
import torch
from transformers import pipeline

# Function to summarize evidence for a single claim
def summarize_evidence_for_claim(claim, q, summarizer):
    # Perform search to retrieve top 5 pieces of evidence
    _, search_results = q.search(claim, limit=5)

    # Extract the evidence snippets
    evidence_texts = [result['snippet'] for result in search_results]

    # Join the evidence snippets into one text block
    evidence_text_block = "\n".join(evidence_texts)

    # Generate the summary for the evidence
    summary = summarizer(evidence_text_block, max_length=200, min_length=100, do_sample=False)

    # Return the summarized evidence
    return summary[0]['summary_text']

# Initialize a summarization pipeline
summarizer = pipeline("summarization", model="facebook/bart-large-cnn")

# Read all claims from the file
with open('data/liar_claims.txt', 'r') as file:
    claims = file.readlines()

# Open the output file for writing
with open('liar_claims_evidence_summarized.txt', 'w') as outfile:
    for claim in claims:
        claim = claim.strip()
        if not claim:  # Skip empty lines
            continue

        try:
            # Summarize evidence for the current claim
            summarized_evidence = summarize_evidence_for_claim(claim, q, summarizer)
            
            # Write the claim and its summarized evidence to the output file
            outfile.write(f"Claim: {claim}\n")
            outfile.write(f"Evidence Summary: {summarized_evidence}\n\n")
        except Exception as e:
            print(f"Error processing claim: {claim}\n{e}\n")

print("Processing complete. Summarized evidence written to liar_claims_evidence.txt.")