<a href="https://colab.research.google.com/github/thiagolaitz/IA368-search-engines/blob/main/Project%2007/splade.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SParse Lexical AnD Expansion model

SPLADE is a neural retrieval model that learns sparse, vocabulary-sized term-weight representations of text, including weights for related terms that do not appear in the text itself (expansion). In this notebook, we use SPLADE's document term weights to build an inverted index, similar to the one used by BM25, and evaluate its search effectiveness on the TREC-COVID dataset.

In [None]:
# Install Hugging Face transformers, which provides the SPLADE model and tokenizer loaders
# used below. The version is not pinned, so results may vary across library releases.
!pip install transformers -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m26.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m74.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import json
import time
from collections import defaultdict

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

from transformers import AutoModelForMaskedLM, AutoTokenizer

# Model

In [None]:
MODEL_NAME = "naver/splade_v2_distil"

if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Masked-LM checkpoint: its per-position vocabulary logits are turned into
# sparse SPLADE term weights when the inverted index is built.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME).to(device)

Downloading (…)lve/main/config.json:   0%|          | 0.00/523 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/268M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/258 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

# Dataset download

The TREC-COVID (2020) dataset is a large-scale information retrieval benchmark released in response to the COVID-19 pandemic. It was created to support research on retrieval systems for COVID-19-related information needs, such as finding relevant scientific articles and answering questions about the pandemic. We use the BEIR distribution of TREC-COVID (corpus, queries, and relevance judgments) to evaluate the performance of our model.

In [None]:
# Corpus with all passages
!wget https://huggingface.co/datasets/BeIR/trec-covid/resolve/main/corpus.jsonl.gz
!gzip -cd corpus.jsonl.gz > corpus.jsonl

--2023-04-26 11:54:27--  https://huggingface.co/datasets/BeIR/trec-covid/resolve/main/corpus.jsonl.gz
Resolving huggingface.co (huggingface.co)... 18.155.68.44, 18.155.68.38, 18.155.68.121, ...
Connecting to huggingface.co (huggingface.co)|18.155.68.44|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/a8/10/a810e88b0e7b233be82b89c1fa6ec2d75efc6d55784c2ada9dcac8434a634f3a/e9e97686e3138eaff989f67c04cd32e8f8f4c0d4857187e3f180275b23e24e85?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27corpus.jsonl.gz%3B+filename%3D%22corpus.jsonl.gz%22%3B&response-content-type=application%2Fgzip&Expires=1682767274&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly9jZG4tbGZzLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2E4LzEwL2E4MTBlODhiMGU3YjIzM2JlODJiODljMWZhNmVjMmQ3NWVmYzZkNTU3ODRjMmFkYTlkY2FjODQzNGE2MzRmM2EvZTllOTc2ODZlMzEzOGVhZmY5ODlmNjdjMDRjZDMyZThmOGY0YzBkNDg1NzE4N2UzZjE4MDI3NWIyM2UyNGU4NT9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSo

In [None]:
# Inspect the first corpus lines to check the _id / title / text fields read by CustomDataset.
!head corpus.jsonl

{"_id": "ug7v899j", "title": "Clinical features of culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi Arabia", "text": "OBJECTIVE: This retrospective chart review describes the epidemiology and clinical features of 40 patients with culture-proven Mycoplasma pneumoniae infections at King Abdulaziz University Hospital, Jeddah, Saudi Arabia. METHODS: Patients with positive M. pneumoniae cultures from respiratory specimens from January 1997 through December 1998 were identified through the Microbiology records. Charts of patients were reviewed. RESULTS: 40 patients were identified, 33 (82.5%) of whom required admission. Most infections (92.5%) were community-acquired. The infection affected all age groups but was most common in infants (32.5%) and pre-school children (22.5%). It occurred year-round but was most common in the fall (35%) and spring (30%). More than three-quarters of patients (77.5%) had comorbidities. Twenty-four isolates (60

In [None]:
class CustomDataset(Dataset):
    """
    PyTorch dataset over a JSONL corpus, one JSON object per line with the keys
    '_id', 'title' and 'text'. Each item is the tokenized "title. text" of a document.

    Args:
        data_path (str): Path to the JSONL file containing the corpus.
        tokenizer: Tokenizer object from the transformers library for encoding the text.
        max_length (int): Number of tokens per document; longer documents are truncated
            and shorter ones padded to this length. Defaults to 256.
    """

    def __init__(self, data_path: str, tokenizer, max_length: int = 256):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.load_data(data_path)

    def load_data(self, data_path):
        """
        Reads the input file and stores the data in the `data` attribute as a list of
        (document id, document content) tuples. The content is the title and the text
        joined by ". ", or just the text when the title is empty.
        Args:
            data_path (str): Path to the input file.
        """
        self.data = []
        # Decode explicitly as UTF-8 (JSON's interchange encoding) instead of relying
        # on the platform's default encoding.
        with open(data_path, "r", encoding="utf-8") as fin:
            for line in fin:
                if not line.strip():
                    continue  # tolerate blank lines, e.g. a trailing newline at EOF
                doc = json.loads(line)
                title = doc.get("title", "")
                text = doc.get("text", "")
                # Avoid a dangling ". " prefix for documents without a title
                content = f"{title}. {text}" if title else text
                self.data.append((doc["_id"], content))

    def __len__(self):
        """
        Returns the number of documents in the dataset.
        Returns:
            int: Number of documents in the dataset.
        """
        return len(self.data)

    def __getitem__(self, index):
        """
        Tokenizes the document at the given index.
        Args:
            index (int): Index of the document to encode.
        Returns:
            dict: {'id': document id, 'encoding': tokenizer output}. The encoding is
            padded/truncated to `max_length`, includes 'special_tokens_mask', and holds
            PyTorch tensors (return_tensors='pt').
        """
        doc_id, content = self.data[index]
        encoding = self.tokenizer(
            content,
            add_special_tokens=True,
            return_special_tokens_mask=True,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        return {
            'id': doc_id,
            'encoding': encoding,
        }

In [None]:
# Documents are read in file order (no shuffling), 64 per batch.
corpus_dataset = CustomDataset(data_path="corpus.jsonl", tokenizer=tokenizer)
corpus_dataloader = DataLoader(corpus_dataset, batch_size=64, shuffle=False)

# Inverted Index

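To implement the inverted index, we use a Python dictionary whose keys are tokens and whose values are lists of `(document ID, score)` tuples. A token's score in a document is its SPLADE weight, $w_j = \max_i \log\left(1 + \mathrm{ReLU}(z_{ij})\right)$, where $z_{ij}$ is the model's logit for vocabulary token $j$ at input position $i$, and special-token positions are masked out. Since most weights are zero, only the non-zero ones are stored.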

In [None]:
def build_inverted_index(dataloader, model, tokenizer, device) -> defaultdict:
    """
    Encodes every document with SPLADE and builds a token -> postings inverted index.

    The weight of vocabulary term j in a document is
    max over input positions i of log(1 + relu(logits[i, j])), ignoring positions that
    are special tokens (per 'special_tokens_mask') or padding (per 'attention_mask').
    Only terms with a non-zero weight are stored.

    Args:
        dataloader: yields batches with keys 'id' (document ids) and 'encoding'
            (tokenizer output whose tensors are shaped [batch x 1 x seq_len], as produced
            by CustomDataset with return_tensors='pt').
        model: masked-LM model whose output has `.logits` of shape
            [batch x seq_len x vocab size].
        tokenizer: tokenizer used to map vocabulary ids back to token strings.
        device: device the model lives on.

    Returns:
        defaultdict(list) mapping token -> list of (doc_id, score) tuples.
    """
    index = defaultdict(list)
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Generating inverted index"):
            encodings = batch["encoding"].to(device)
            # Drop only the per-document dim 1; a bare .squeeze() would also drop the
            # batch dim whenever a batch contains a single document.
            input_ids = encodings["input_ids"].squeeze(1)
            attention_mask = encodings["attention_mask"].squeeze(1)
            special_tokens_mask = encodings["special_tokens_mask"].squeeze(1)

            # 1 for regular tokens, 0 for special tokens and padding: [batch x seq_len x 1]
            token_mask = ((1 - special_tokens_mask) * attention_mask).unsqueeze(-1)

            # [batch x seq_len x vocab size]
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits

            # In-place ops: logits is not needed afterwards, and each out-of-place step
            # would allocate another tensor as large as the logits.
            weights = logits.relu_().log1p_().mul_(token_mask)

            # Max-pool over the sequence -> one weight per vocabulary term: [batch x vocab size].
            # Move it to the CPU once per batch instead of syncing per posting via .item().
            scores = weights.max(dim=1).values.cpu()

            for doc_id, doc_scores in zip(batch["id"], scores):
                # as_tuple=True always yields a 1-D index tensor, so documents with
                # exactly one non-zero term are kept as well.
                term_ids = torch.nonzero(doc_scores, as_tuple=True)[0]
                tokens = tokenizer.convert_ids_to_tokens(term_ids.tolist())
                for token, weight in zip(tokens, doc_scores[term_ids].tolist()):
                    index[token].append((doc_id, weight))
    return index


inverted_index = build_inverted_index(corpus_dataloader, model, tokenizer, device)

Generating inverted index:   0%|          | 0/2678 [00:00<?, ?it/s]

# Qrels and topics

In [None]:
# Download the BEIR TREC-COVID queries and decompress them to queries.jsonl (one JSON query per line).
!wget https://huggingface.co/datasets/BeIR/trec-covid/resolve/main/queries.jsonl.gz
!gzip -dc queries.jsonl.gz > queries.jsonl

--2023-04-26 13:10:57--  https://huggingface.co/datasets/BeIR/trec-covid/resolve/main/queries.jsonl.gz
Resolving huggingface.co (huggingface.co)... 18.155.68.44, 18.155.68.38, 18.155.68.121, ...
Connecting to huggingface.co (huggingface.co)|18.155.68.44|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/a8/10/a810e88b0e7b233be82b89c1fa6ec2d75efc6d55784c2ada9dcac8434a634f3a/9eadcc2cdf140addc9dae83648bb2c6611f5e4b66eaed7475fa5a0ca48eda371?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27queries.jsonl.gz%3B+filename%3D%22queries.jsonl.gz%22%3B&response-content-type=application%2Fgzip&Expires=1682772110&Policy=eyJTdGF0ZW1lbnQiOlt7IlJlc291cmNlIjoiaHR0cHM6Ly9jZG4tbGZzLmh1Z2dpbmdmYWNlLmNvL3JlcG9zL2E4LzEwL2E4MTBlODhiMGU3YjIzM2JlODJiODljMWZhNmVjMmQ3NWVmYzZkNTU3ODRjMmFkYTlkY2FjODQzNGE2MzRmM2EvOWVhZGNjMmNkZjE0MGFkZGM5ZGFlODM2NDhiYjJjNjYxMWY1ZTRiNjZlYWVkNzQ3NWZhNWEwY2E0OGVkYTM3MT9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9u

In [None]:
# Each line of queries.jsonl is one query; keep (query id, query text) pairs.
with open("queries.jsonl", "r") as queries_file:
    topics = [
        (record["_id"], record["text"])
        for record in map(json.loads, queries_file)
    ]

In [None]:
# Qrels with all relevances
!wget https://huggingface.co/datasets/BeIR/trec-covid-qrels/raw/main/test.tsv

--2023-04-26 13:10:58--  https://huggingface.co/datasets/BeIR/trec-covid-qrels/raw/main/test.tsv
Resolving huggingface.co (huggingface.co)... 18.155.68.44, 18.155.68.38, 18.155.68.121, ...
Connecting to huggingface.co (huggingface.co)|18.155.68.44|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 980831 (958K) [text/plain]
Saving to: ‘test.tsv’


2023-04-26 13:10:58 (2.21 MB/s) - ‘test.tsv’ saved [980831/980831]



In [None]:
# Convert the BEIR "qid<TAB>doc_id<TAB>relevance" TSV into TREC qrels lines
# ("qid 0 doc_id relevance"), dropping the header row.
with open("qrels.tsv", "w") as qrels_out, open("test.tsv", "r") as qrels_in:
    next(qrels_in, None)  # skip the header row
    for row in qrels_in:
        query_id, corpus_id, label = row.strip().split("\t")
        qrels_out.write(f"{query_id}\t0\t{corpus_id}\t{label}\n")

# TREC Run

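The functions below score each query against the inverted index and write the results to a run file in the TREC format, one line per retrieved document:

`QID Q0 DOC_ID RANK SCORE LABEL`

Queries are not encoded by SPLADE: after stop-word removal, the query is tokenized, and a document's score is the sum of its stored weights for the query tokens.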

In [None]:
import nltk
from nltk.corpus import stopwords

# Fetch the NLTK stop-word corpus and load the English list used to filter queries
# before they are looked up in the inverted index.
nltk.download('stopwords')
stopwords_list = stopwords.words('english')

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Unzipping corpora/stopwords.zip.


In [None]:
def get_scores(query: str, top_k: int = 1000, index=None, stop_words=None, query_tokenizer=None):
    """
    Scores documents for a query against the SPLADE inverted index.

    The query is not expanded by SPLADE: after stop-word removal it is tokenized, and
    each document's score is the sum of its stored weights for the query tokens
    (a token repeated in the query is counted once per occurrence).

    Args:
        query: free-text query.
        top_k: maximum number of documents to return.
        index: mapping token -> list of (doc_id, score) postings; defaults to the global
            `inverted_index`. It is only read, never modified.
        stop_words: collection of lowercase words to drop from the query; defaults to
            the global `stopwords_list`.
        query_tokenizer: tokenizer used to split the query; defaults to the global `tokenizer`.

    Returns:
        list of (doc_id, score) tuples sorted by descending score, at most `top_k` long.
    """
    if index is None:
        index = inverted_index
    if stop_words is None:
        stop_words = stopwords_list
    if query_tokenizer is None:
        query_tokenizer = tokenizer

    # Remove stopwords case-insensitively, so capitalized words such as "What" are dropped too
    filtered_query = [word for word in query.split() if word.lower() not in stop_words]

    # No return_tensors: a plain list of ids stays 1-D even for single-token queries
    # (squeezing a [1 x 1] tensor gives a 0-d tensor that convert_ids_to_tokens cannot iterate).
    input_ids = query_tokenizer(
        " ".join(filtered_query),
        add_special_tokens=False,
        max_length=256,
        truncation=True,
        padding='do_not_pad',
    )["input_ids"]
    tokens = query_tokenizer.convert_ids_to_tokens(input_ids)

    # Sum the document scores of every query token. .get() avoids inserting empty
    # postings lists into a defaultdict index for tokens that have no postings.
    scores = defaultdict(float)
    for token in tokens:
        for doc_id, score in index.get(token, ()):
            scores[doc_id] += score
    return sorted(scores.items(), key=lambda item: item[1], reverse=True)[:top_k]

def get_run(path: str, topics: list, top_k: int):
    """
    Retrieves documents for every topic and writes them to `path` as a TREC run.

    Each output line is tab-separated: query id, the literal "Q0", document id,
    1-based rank, score, and the run tag "Splade".

    Args:
        path: the path to save the run
        topics: a list of (query id, query text) tuples to evaluate
        top_k: number of passages to retrieve per query
    """
    with open(path, 'w') as run_file:
        for query_id, query_text in tqdm(topics, desc="Running queries"):
            ranked = get_scores(query_text, top_k)
            for position, (doc_id, score) in enumerate(ranked, start=1):
                run_file.write(f"{query_id}\tQ0\t{doc_id}\t{position}\t{score}\tSplade\n")

In [None]:
# Retrieve the top 1000 documents per topic and save them as a TREC run file.
get_run(path="run_splade.tsv", topics=topics, top_k=1000)

Running queries:   0%|          | 0/50 [00:00<?, ?it/s]

# Results

In [None]:
# pyserini supplies the trec_eval wrapper used in the evaluation cell below;
# faiss-cpu is installed pinned to version 1.7.2.
!pip install pyserini -q
!pip install faiss-cpu==1.7.2 -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.1/154.1 MB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m72.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.3/13.3 MB[0m [31m69.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.0/5.0 MB[0m [31m105.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m77.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m188.5/188.5 kB[0m [31m22.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# trec_eval via pyserini: report nDCG@10 (ndcg_cut.10) and MAP ("-mmap" is "-m map").
# -c averages over every query in the qrels (queries without results contribute 0);
# -l 2 makes binary measures such as MAP count only judgments >= 2 as relevant.
!python -m pyserini.eval.trec_eval -c -m ndcg_cut.10 -mmap -l 2 qrels.tsv run_splade.tsv

Downloading https://search.maven.org/remotecontent?filepath=uk/ac/gla/dcs/terrierteam/jtreceval/0.0.5/jtreceval-0.0.5-jar-with-dependencies.jar to /root/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar...
jtreceval-0.0.5-jar-with-dependencies.jar: 1.79MB [00:02, 757kB/s]                 
Running command: ['java', '-jar', '/root/.cache/pyserini/eval/jtreceval-0.0.5-jar-with-dependencies.jar', '-c', '-m', 'ndcg_cut.10', '-mmap', '-l', '2', 'qrels.tsv', 'run_splade.tsv']
Results:
map                   	all	0.1663
ndcg_cut_10           	all	0.6369
