In [1]:
! pip install -Uq datasets splade-index numba "jax[cpu]"

In [2]:
import os

# Restrict JAX to its CPU backend (the install cell pulls in the CPU-only "jax[cpu]" build).
# Presumably this has to run before anything imports jax for it to take effect — TODO confirm.
os.environ["JAX_PLATFORMS"]="cpu"

### **Load Dataset**

In [None]:
from datasets import load_dataset

# Load the "rasyosef/msmarco" dataset (presumably fetched from the Hugging Face Hub
# on first use — TODO confirm). The result is indexed by split name further down.
msmarco = load_dataset("rasyosef/msmarco")
# Bare expression: display the loaded object as the cell output.
msmarco

In [4]:
# Shuffle the "dev" split with a fixed seed so the row order is reproducible.
# For a quicker run on a subset, chain `.select(range(10_000))` onto the shuffle.
test_dataset = msmarco["dev"].shuffle(seed=42)
# Bare expression: display the dataset summary as the cell output.
test_dataset

Dataset({
    features: ['query_id', 'query', 'positives', 'negatives'],
    num_rows: 55577
})

In [5]:
import hashlib
from datasets import concatenate_datasets

# MD5 is used purely as a content fingerprint for deduplicating documents
def md5(text):
  """Return the hexadecimal MD5 digest of the string `text`.

  The string is converted to bytes with `str.encode()` (default encoding)
  before hashing, so equal strings always map to the same digest.
  """
  return hashlib.md5(text.encode()).hexdigest()

# Alias: the rest of this cell refers to the shuffled dev split as `dev_dataset`.
dev_dataset = test_dataset

# query_id -> query text
dev_queries = {
  qid: query
  for qid, query in zip(dev_dataset["query_id"], dev_dataset["query"])
}

# md5(passage) -> passage text. Keying by the hash collapses passages that occur
# under several queries into a single corpus entry.
dev_corpus = {}
for row in dev_dataset:
  # For each row, positives are inserted before negatives.
  for passage in [*row["positives"], *row["negatives"]]:
    dev_corpus[md5(passage)] = passage

# query_id -> list of corpus keys (md5 hashes) of that query's positive passages
dev_relevant_docs = {
  qid: [md5(pos) for pos in positives]
  for qid, positives in zip(dev_dataset["query_id"], dev_dataset["positives"])
}

# Sizes of the corpus, the query set and the relevance map
(len(dev_corpus), len(dev_queries), len(dev_relevant_docs))

(542280, 55577, 55577)

### **Index the Documents**

In [None]:
import torch
from sentence_transformers import SparseEncoder

# Use the GPU when one is visible and fall back to the CPU otherwise, so this cell
# does not fail on machines without CUDA. (The timings quoted in this notebook's
# summary were measured on a T4 GPU; CPU runs will be much slower.)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Download a SPLADE model from the 🤗 Hub
splade = SparseEncoder("rasyosef/splade-mini", device=device)

# Bare expression: display the loaded model as the cell output.
splade

In [8]:
# Materialise the deduplicated passages as a list. A passage's position in this
# list is presumably what the index later reports as its doc id — TODO confirm.
corpus = [*dev_corpus.values()]
len(corpus)

542280

In [None]:
from splade_index import SPLADE

# Create the SPLADE retriever and index the corpus.
# NOTE: this is the slow step — the summary at the end of this notebook reports
# roughly 12 minutes for ~500k documents with splade-mini on a T4 GPU.
retriever = SPLADE()
retriever.index(model=splade, documents=corpus)

### **Query the index**

In [10]:
# Benchmark on the first 5,000 query strings, taken in dict insertion order.
queries = [*dev_queries.values()][:5000]
# Show how many queries were kept, plus a small sample of them.
(len(queries), queries[:3])

(5000,
 ['is natural gas renewable',
  'how many hours of sunlight do succulents need?',
  'what is an acute / obtuse triangle'])

In [None]:
# Retrieve the top-k documents for every query and report the mean latency per query.
# The result object exposes doc_ids, documents and scores; all three are arrays of
# shape (n_queries, k).
from time import perf_counter, time  # `time` stays imported in case another cell still calls it

# perf_counter() is a monotonic, high-resolution clock meant for measuring intervals;
# time() follows the system clock and can jump if that clock is adjusted mid-run.
start_time = perf_counter()

results = retriever.retrieve(queries, k=5)
doc_ids, result_docs, scores = results.doc_ids, results.documents, results.scores

# The measured span covers the entire retrieve() call. The recorded output shows a
# "Batches" progress bar before the index-retrieve bar, so query encoding appears
# to be included in this figure.
time_taken = perf_counter() - start_time
print(f"Average retrieval time per query: {time_taken*1000/len(queries):.2f} ms")

Batches: 100%|██████████| 157/157 [00:01<00:00, 102.26it/s]


SPLADE Index Retrieve:   0%|          | 0/5000 [00:00<?, ?it/s]

Average retrieval time per query: 6.11 ms


In [12]:
# Show the first query together with its ranked hits (row 0 of each result array).
print("Query:", queries[0])

print("Retrieved Documents:")
top_hits = zip(doc_ids[0], result_docs[0], scores[0])
for rank, (doc_id, doc, score) in enumerate(top_hits, start=1):
    # Ranks are 1-based; scores are printed with two decimals.
    print(f"Rank {rank} (score: {score:.2f}) (doc_idx: {doc_id}):", doc)

Query: is natural gas renewable
Retrieved Documents:
Rank 1 (score: 27.00) (doc_idx: 7): Renewable natural gas can be produced economically, and distributed via the existing gas grid, making it an attractive means of supplying existing premises with renewable heat and renewable gas energy, while requiring no extra capital outlay of the customer.
Rank 2 (score: 25.99) (doc_idx: 0): Renewable natural gas. Renewable natural gas, also known as sustainable natural gas, is a biogas which has been upgraded to a quality similar to fossil natural gas. A biogas is a gas methane obtained from biomass. By upgrading the quality to that of natural gas, it becomes possible to distribute the gas to customers via the existing gas grid, within existing appliances. Renewable natural gas is a subset of synthetic natural gas or substitute natural gas (SNG).
Rank 3 (score: 25.02) (doc_idx: 6): While conventional natural gas is not considered a renewable fuel, biomethane or renewable natural gas can be produ

### **Summary:**
- splade-tiny takes 9m to index 500k docs and 5.94ms/query on a T4 16GB GPU
- splade-mini takes 12m to index 500k docs and 6.11ms/query on a T4 16GB GPU