<a href="https://colab.research.google.com/github/saniyalakka19/neural-information-retrieval/blob/main/notebooks/SPLADE_Inference_with_Boolean_Filter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking

This notebook gives a minimal example of how to use SPLADE.

* In this repo, we provide weights for 2 models (in the `weights` folder)
* See [Naver Labs Europe website](https://europe.naverlabs.com/research/machine-learning-and-optimization/splade-models/) for more up-to-date models under various settings
* We also provide two new models via Hugging Face (https://huggingface.co/naver)

| model | MRR@10 (MS MARCO dev) | recall@1000 (MS MARCO dev) | expected FLOPS | ~ avg q length | ~ avg d length | 
| --- | --- | --- | --- | --- | --- |
| `splade_max` (**v2**) | 34.0 | 96.5 | 1.32 | 18 | 92 |
| `distilsplade_max` (**v2**) | 36.8 | 97.9 | 3.82 | 25 | 232 |
| `naver/splade-cocondenser-selfdistil` (**v2bis**, [HF](https://huggingface.co/naver/splade-cocondenser-selfdistil))| 37.6 | 98.4 | 2.32 | 56 | 134 |
| `naver/splade-cocondenser-ensembledistil` (**v2bis**, [HF](https://huggingface.co/naver/splade-cocondenser-ensembledistil)) | 38.3 | 98.3  | 1.85 | 44 | 120 |

In [None]:
!git clone https://github.com/naver/splade.git

Cloning into 'splade'...
remote: Enumerating objects: 514, done.
remote: Counting objects: 100% (47/47), done.
remote: Compressing objects: 100% (24/24), done.
remote: Total 514 (delta 31), reused 23 (delta 23), pack-reused 467
Receiving objects: 100% (514/514), 3.07 MiB | 17.66 MiB/s, done.
Resolving deltas: 100% (273/273), done.
Filtering content: 100% (2/2), 511.12 MiB | 28.54 MiB/s, done.


In [None]:
pip install ./splade

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Processing ./splade
  Preparing metadata (setup.py) ... done
Collecting transformers==4.18.0
  Downloading transformers-4.18.0-py3-none-any.whl (4.0 MB)
Collecting omegaconf==2.1.2
  Downloading omegaconf-2.1.2-py3-none-any.whl (74 kB)
Collecting antlr4-python3-runtime==4.8
  Downloading antlr4-python3-runtime-4.8.tar.gz (112 kB)
  Preparing metadata (setup.py) ... done
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp38-cp38-manylinux_2_12_x86_64.manylinux2010_x86_64.wh

In [None]:
import numpy as np

In [None]:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from splade.models.transformer_rep import Splade

In [None]:
# set the dir for trained weights

##### v2
# model_type_or_dir = "weights/splade_max"
# model_type_or_dir = "weights/distilsplade_max"

### v2bis, directly download from Hugging Face
# model_type_or_dir = "naver/splade-cocondenser-selfdistil"
model_type_or_dir = "naver/splade-cocondenser-ensembledistil"


In [None]:
# loading model and tokenizer

model = Splade(model_type_or_dir, agg="max")
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir)
reverse_voc = {v: k for k, v in tokenizer.vocab.items()}

Downloading:   0%|          | 0.00/670 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/418M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [None]:
# example document from MS MARCO passage collection (doc_id = 8003157)
doc = """
    Glass and Thermal Stress. 
    Thermal Stress is created 
    when one area of a glass 
    pane gets hotter than an 
    adjacent area. If the stress 
    is too great then the glass 
    will crack. The stress level 
    at which the glass will break 
    is governed by several factors.
    """

In [None]:
# now compute the document representation
with torch.no_grad():
    m = model(
        d_kwargs=tokenizer(
            doc, return_tensors="pt"
            )
        )
    doc_rep = m["d_rep"].squeeze()
    # (sparse) doc rep in voc space, shape (30522,)

# get the number of non-zero dimensions in the rep:
col = torch.nonzero(doc_rep).squeeze().cpu().tolist()
print("number of actual dimensions: ", len(col))

# now let's inspect the bow representation:
weights = doc_rep[col].cpu().tolist()
d = {k: v for k, v in zip(col, weights)}
sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}
bow_rep = []
for k, v in sorted_d.items():
    bow_rep.append((reverse_voc[k], round(v, 2)))
print("SPLADE BOW rep:\n", bow_rep)



number of actual dimensions:  126
SPLADE BOW rep:
 [('stress', 2.25), ('glass', 2.23), ('thermal', 2.18), ('glasses', 1.65), ('pan', 1.62), ('heat', 1.56), ('stressed', 1.42), ('crack', 1.31), ('break', 1.12), ('cracked', 1.1), ('hot', 0.93), ('created', 0.9), ('factors', 0.81), ('broken', 0.73), ('caused', 0.71), ('too', 0.71), ('damage', 0.69), ('if', 0.68), ('hotter', 0.65), ('governed', 0.61), ('heating', 0.59), ('temperature', 0.59), ('adjacent', 0.59), ('cause', 0.58), ('effect', 0.57), ('fracture', 0.56), ('bradford', 0.55), ('strain', 0.53), ('hammer', 0.51), ('brian', 0.48), ('error', 0.47), ('windows', 0.45), ('will', 0.45), ('reaction', 0.42), ('create', 0.42), ('windshield', 0.41), ('heated', 0.41), ('factor', 0.4), ('cracking', 0.39), ('failure', 0.38), ('mechanical', 0.38), ('when', 0.38), ('formed', 0.38), ('bolt', 0.38), ('mechanism', 0.37), ('warm', 0.37), ('areas', 0.36), ('area', 0.36), ('energy', 0.34), ('disorder', 0.33), ('barry', 0.33), ('shock', 0.32), ('determi
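The weights above come from the model loaded with `agg="max"`. The following is a minimal sketch of that aggregation, assuming the SPLADE v2 max-pooling formulation in which each vocabulary weight is the maximum over input tokens of `log(1 + ReLU(logit))`. The function name `splade_max_pool` and the toy logits are hypothetical and only illustrate why most of the vocabulary dimensions end up at exactly zero.

```python
import math

def splade_max_pool(token_logits):
    """Collapse per-token vocabulary logits into one sparse vector.

    token_logits: list of rows, one per input token; each row holds one
    logit per vocabulary entry. Every logit is mapped to
    log(1 + ReLU(logit)) and the maximum over tokens is kept.
    """
    vocab_size = len(token_logits[0])
    rep = [0.0] * vocab_size
    for row in token_logits:
        for j, logit in enumerate(row):
            weight = math.log1p(max(logit, 0.0))
            rep[j] = max(rep[j], weight)
    return rep

# Toy input: 3 tokens over a 5-entry vocabulary (made-up numbers).
toy_logits = [
    [2.0, -1.0, 0.0, 0.5, -3.0],
    [0.1, -0.2, 0.0, 4.0, -1.0],
    [-5.0, -0.1, 0.0, 1.0, -2.0],
]
toy_rep = splade_max_pool(toy_logits)
nonzero_dims = [j for j, w in enumerate(toy_rep) if w > 0]
print(nonzero_dims)
```

A vocabulary entry gets a non-zero weight only if at least one input token produces a positive logit for it, which is how terms that never appear in the passage (expansion terms) can enter the representation.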

In [None]:
d

{2043: 0.3811924457550049,
 2063: 0.013468118384480476,
 2065: 0.679470956325531,
 2097: 0.44665437936782837,
 2131: 0.11370912939310074,
 2138: 0.024708406999707222,
 2181: 0.35583481192588806,
 2205: 0.7077890038490295,
 2307: 0.07926920056343079,
 2453: 0.18895027041435242,
 2504: 0.3077377378940582,
 2550: 0.1225956529378891,
 2580: 0.9020834565162659,
 2600: 0.11568041145801544,
 2719: 0.3802264332771301,
 2752: 0.36217206716537476,
 2943: 0.33946168422698975,
 2980: 0.9349669814109802,
 3103: 0.015637045726180077,
 3221: 2.2316195964813232,
 3239: 0.2019173800945282,
 3277: 0.24346952140331268,
 3291: 0.24687933921813965,
 3303: 0.7140827178955078,
 3332: 0.2980433404445648,
 3338: 1.119208812713623,
 3382: 0.1564858853816986,
 3399: 0.3090040981769562,
 3426: 0.579511284828186,
 3443: 0.4152419865131378,
 3466: 0.5718740820884705,
 3581: 0.2001609355211258,
 3645: 0.4486302435398102,
 3684: 1.5606915950775146,
 3714: 0.7297546863555908,
 3778: 0.2119736671447754,
 3798: 0.067564

In [None]:
np.zeros(len(tokenizer.vocab.items()))

array([0., 0., 0., ..., 0., 0., 0.])

In [None]:
min([v for k,v in tokenizer.vocab.items()])

0

In [None]:
weights

[0.3811924457550049,
 0.013468118384480476,
 0.679470956325531,
 0.44665437936782837,
 0.11370912939310074,
 0.024708406999707222,
 0.35583481192588806,
 0.7077890038490295,
 0.07926920056343079,
 0.18895027041435242,
 0.3077377378940582,
 0.1225956529378891,
 0.9020834565162659,
 0.11568041145801544,
 0.3802264332771301,
 0.36217206716537476,
 0.33946168422698975,
 0.9349669814109802,
 0.015637045726180077,
 2.2316195964813232,
 0.2019173800945282,
 0.24346952140331268,
 0.24687933921813965,
 0.7140827178955078,
 0.2980433404445648,
 1.119208812713623,
 0.1564858853816986,
 0.3090040981769562,
 0.579511284828186,
 0.4152419865131378,
 0.5718740820884705,
 0.2001609355211258,
 0.4486302435398102,
 1.5606915950775146,
 0.7297546863555908,
 0.2119736671447754,
 0.06756444275379181,
 0.3686661422252655,
 0.2521461248397827,
 0.6917024254798889,
 0.1785074770450592,
 0.0913749560713768,
 0.24214418232440948,
 0.32159143686294556,
 0.10773749649524689,
 0.17504768073558807,
 0.4820708930492

In [None]:
set(doc) - set([x[0] for x in bow_rep])

{'\n',
 ' ',
 '.',
 'G',
 'I',
 'S',
 'T',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'r',
 's',
 't',
 'v',
 'w',
 'y'}

In [None]:
set([x[0] for x in bow_rep]) - set(doc)

{'##e',
 '##glass',
 '##ink',
 'adjacent',
 'affect',
 'anger',
 'area',
 'areas',
 'bailey',
 'barry',
 'because',
 'bolt',
 'bradford',
 'brake',
 'break',
 'brian',
 'brittle',
 'broken',
 'burn',
 'burst',
 'cause',
 'caused',
 'causes',
 'ceramic',
 'chance',
 'collapse',
 'cool',
 'crack',
 'cracked',
 'cracking',
 'cracks',
 'crash',
 'create',
 'created',
 'crush',
 'damage',
 'determined',
 'disorder',
 'drink',
 'effect',
 'energy',
 'error',
 'eye',
 'factor',
 'factors',
 'fail',
 'failure',
 'fatigue',
 'fireplace',
 'flash',
 'formation',
 'formed',
 'fra',
 'fracture',
 'fragment',
 'frank',
 'friction',
 'fridge',
 'gage',
 'get',
 'gilbert',
 'glass',
 'glasses',
 'governed',
 'gravity',
 'great',
 'hammer',
 'hazard',
 'heat',
 'heated',
 'heating',
 'hot',
 'hotter',
 'hottest',
 'if',
 'impact',
 'injury',
 'interference',
 'issue',
 'knock',
 'leak',
 'level',
 'levels',
 'mechanical',
 'mechanism',
 'might',
 'pan',
 'parker',
 'physics',
 'pressure',
 'problem',
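Note that `set(doc)` iterates over the characters of the string, so the two set differences above compare single characters against whole terms. A rough sketch of a word-level comparison follows. The regex split is an assumption and only approximates the WordPiece tokenizer, and `toy_doc` and `toy_bow` are shortened stand-ins for `doc` and `bow_rep`.

```python
import re

def word_set(text):
    """Lowercased alphabetic words in the text (rough stand-in for WordPiece)."""
    return set(re.findall(r"[a-z]+", text.lower()))

# Shortened stand-ins for `doc` and `bow_rep` from the cells above.
toy_doc = "Thermal Stress is created when one area of a glass pane gets hotter."
toy_bow = [("stress", 2.25), ("glass", 2.23), ("thermal", 2.18),
           ("heat", 1.56), ("crack", 1.31)]

doc_words = word_set(toy_doc)
bow_terms = {term for term, _ in toy_bow}

expansion_terms = sorted(bow_terms - doc_words)  # in the rep, not in the text
dropped_terms = sorted(doc_words - bow_terms)    # in the text, not in the rep
print(expansion_terms)
print(dropped_terms)
```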


Open question: at this point, should all the passages be passed through the model to construct the sparse matrix?
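One possible answer, sketched under assumptions: encode every passage once, keep only the non-zero dimensions, and store them in an inverted index so that a query touches only the terms it contains. The helper names (`to_sparse`, `build_inverted_index`, `score`) and the 6-term toy vectors are hypothetical; the dense lists stand in for the `d_rep` outputs of the model.

```python
from collections import defaultdict

def to_sparse(dense_vec):
    """Keep only the non-zero dimensions as {token_id: weight}."""
    return {i: w for i, w in enumerate(dense_vec) if w != 0}

def build_inverted_index(sparse_docs):
    """Map token_id -> list of (doc_id, weight) postings."""
    index = defaultdict(list)
    for doc_id, vec in enumerate(sparse_docs):
        for token_id, weight in vec.items():
            index[token_id].append((doc_id, weight))
    return index

def score(query_vec, index, num_docs):
    """Dot product between a sparse query and every indexed document."""
    scores = [0.0] * num_docs
    for token_id, q_weight in query_vec.items():
        for doc_id, d_weight in index.get(token_id, []):
            scores[doc_id] += q_weight * d_weight
    return scores

# Hypothetical dense outputs over a 6-term vocabulary (stand-ins for d_rep).
dense_passages = [
    [0.0, 1.5, 0.0, 0.2, 0.0, 0.0],
    [0.9, 0.0, 0.0, 0.0, 0.4, 0.0],
]
sparse_passages = [to_sparse(v) for v in dense_passages]
index = build_inverted_index(sparse_passages)
query = to_sparse([0.0, 2.0, 0.0, 0.0, 1.0, 0.0])
passage_scores = score(query, index, len(sparse_passages))
print(passage_scores)
```

For a collection as small as the toy dataset, dense NumPy arrays are also workable; the index matters more as the number of passages grows.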

Load the toy dataset from the repo with the Hugging Face `datasets` CSV loader:  
https://huggingface.co/docs/datasets/loading#csv


In [None]:
!head /content/splade/data/toy_data/triplets/raw.tsv

is a little caffeine ok during pregnancy	We don’t know a lot about the effects of caffeine during pregnancy on you and your baby. So it’s best to limit the amount you get each day. If you’re pregnant, limit caffeine to 200 milligrams each day. This is about the amount in 1½ 8-ounce cups of coffee or one 12-ounce cup of coffee.	It is generally safe for pregnant women to eat chocolate because studies have shown to prove certain benefits of eating chocolate during pregnancy. However, pregnant women should ensure their caffeine intake is below 200 mg per day.
what fruit is native to australia	Passiflora herbertiana. A rare passion fruit native to Australia. Fruits are green-skinned, white fleshed, with an unknown edible rating. Some sources list the fruit as edible, sweet and tasty, while others list the fruits as being bitter and inedible.assiflora herbertiana. A rare passion fruit native to Australia. Fruits are green-skinned, white fleshed, with an unknown edible rating. Some sou

In [None]:
!pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.9.0-py3-none-any.whl (462 kB)
Collecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting xxhash
  Downloading xxhash-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (213 kB)
Collecting multiprocess
  Downloading multiprocess-0.70.14-py38-none-any.whl (132 kB)
Collecting urllib3<1.27,>=1.21.1
  Downloading urllib3-1.26.14-py2.py3-none-any.whl (140 kB)

In [None]:
# get the toy dataset
from datasets import load_dataset
fp = "/content/splade/data/toy_data/triplets/raw.tsv"
triplets = load_dataset(
    "csv", 
    data_files=fp,
    sep="\t",
    header=None,
    column_names=["query", "positive", "negative"],
    split="train"
)
# next: create the expanded (sparse) representations with SPLADE



Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-b08c220f5fefeee1/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-b08c220f5fefeee1/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317. Subsequent calls will reuse this data.


In [None]:
triplets

Dataset({
    features: ['query', 'positive', 'negative'],
    num_rows: 100
})

In [None]:
import datasets
from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset
from tqdm.auto import tqdm

# KeyDataset yields only the "query" field of each dataset row.
# Note: queries are encoded through the document branch (d_kwargs / "d_rep") here.
q_array_list = []
for out in tqdm(KeyDataset(triplets, "query")):
    # inference only, so skip gradient tracking
    with torch.no_grad():
        pt = model(
            d_kwargs=tokenizer(
                out, return_tensors="pt"
                )
            )
    rep = pt["d_rep"].squeeze()

    q_array_list.append(rep.detach().numpy())

  0%|          | 0/100 [00:00<?, ?it/s]



In [None]:
c = np.array(rep.detach()).nonzero()[0]
weights = rep[c].cpu().tolist()

In [None]:
combo = list(set(triplets["positive"] + triplets["negative"]))

In [None]:
p_array_list = []
# positives first, then negatives, so passage i (i < 100) is the positive for query i
for out in tqdm(triplets["positive"] + triplets["negative"]):
    pt = model(
        d_kwargs=tokenizer(
            out, return_tensors="pt"
            )
        )
    rep = pt["d_rep"].squeeze()
    p_array_list.append(rep.detach().numpy())

  0%|          | 0/200 [00:00<?, ?it/s]

In [None]:
query_matrix = np.matrix(q_array_list)
passage_matrix = np.matrix(p_array_list)

In [None]:
rev_voc = lambda x: reverse_voc[x]
applyall = np.vectorize(rev_voc)
passage_text = []
for i in [np.array(x).squeeze().nonzero()[0] for x in passage_matrix]:
    passage_text.append(' '.join(applyall(i)))

In [None]:
query_text = []
for i in [np.array(x).squeeze().nonzero()[0] for x in query_matrix]:
    query_text.append(' '.join(applyall(i)))

In [None]:
def get_vec(arr):
    # binary (Boolean) vector: 1 at every vocabulary index listed in arr, 0 elsewhere
    z = np.zeros(len(tokenizer.vocab.items()))
    for v in arr:
        z[int(v)] = 1
    return z

In [None]:
passage_matrix

matrix([[0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        ...,
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.],
        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)

In [None]:
qv = [get_vec(np.array(q).squeeze().nonzero()[0]) for q in query_matrix]
pv = [get_vec(np.array(p).squeeze().nonzero()[0]) for p in passage_matrix]

In [None]:
# get the scores of each passage for each query
q_res = []
for q in qv:
    res = []
    for p in pv:
        score = np.dot(q, p)
        res.append(score)
    q_res.append(np.array(res))
q_mat = np.matrix(q_res)
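Because `get_vec` writes a 1 at every non-zero index, `np.dot(q, p)` above counts the vocabulary entries shared by a query and a passage rather than using the SPLADE weights. The sketch below contrasts the two scoring choices on made-up vectors. The function names and numbers are hypothetical, and it is not a claim about which variant the notebook intends.

```python
def binary_overlap(q, p):
    """Count vocabulary entries that are non-zero in both vectors."""
    return sum(1 for qw, pw in zip(q, p) if qw != 0 and pw != 0)

def weighted_dot(q, p):
    """Dot product that keeps the term weights."""
    return sum(qw * pw for qw, pw in zip(q, p))

# Made-up 4-term vectors: passage 0 matches one strong term,
# passage 1 matches two weak terms.
toy_query = [2.0, 0.0, 0.1, 0.0]
toy_passages = [
    [1.8, 0.0, 0.0, 0.0],
    [0.1, 0.0, 0.1, 0.3],
]
overlap_scores = [binary_overlap(toy_query, p) for p in toy_passages]
weighted_scores = [weighted_dot(toy_query, p) for p in toy_passages]
print(overlap_scores)
print(weighted_scores)
```

The two scorers rank the toy passages in opposite orders, so the choice between Boolean overlap and weighted scoring can change the MRR computed afterwards.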

# Evaluate Performance
Calculate the MRR@10  
https://machinelearning.wtf/terms/mean-reciprocal-rank-mrr/
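As a small worked example, MRR@k averages `1 / rank` of the relevant passage over all queries and counts a query as 0 when that rank is greater than `k`. The helper `mrr_at_k` and the toy ranks are illustrative only.

```python
def mrr_at_k(ranks, k=10):
    """ranks: 1-based rank of the relevant passage for each query."""
    total = sum(1.0 / r for r in ranks if r <= k)
    return total / len(ranks)

# Four hypothetical queries; the third one misses the top 10 entirely.
toy_ranks = [1, 5, 200, 2]
toy_mrr = mrr_at_k(toy_ranks, k=10)
print(round(toy_mrr, 4))  # (1 + 0.2 + 0 + 0.5) / 4
```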

In [None]:
from scipy.stats import rankdata

In [None]:
# calculate the mean reciprocal rank
rr_sum = 0
MRR_RANK = 10
for i, passage_scores in enumerate(q_mat):
    n = np.array(passage_scores).flatten()
    ranks = rankdata(n, method="min")
    r = len(n) - ranks[i] + 1
    if r > MRR_RANK:  # if the rank is above the threshold, then skip
        continue
    rr_sum += 1 / r
mrr_10 = rr_sum / q_mat.shape[0]
print(f"MRR@10: {mrr_10*100}")

MRR@10: 91.9


In [None]:
# using a pipeline
from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset

# load the topic classifier
MODEL = "cardiffnlp/tweet-topic-21-multi"
generator = pipeline(model=MODEL)

Downloading:   0%|          | 0.00/1.33k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.84k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/476M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.27k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/780k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/239 [00:00<?, ?B/s]

In [None]:
# classify the passages
passage_class_list = []
for out in generator(passage_text, truncation=True, max_length=512):
    passage_class_list.append(out)

In [None]:
query_class_list = []
for out in tqdm(generator(query_text, truncation=True, max_length=512)):
    query_class_list.append(out)

  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# get the scores of each passage for each query
q_res = []
for q in qv:
    res = []
    for p in pv:
        score = np.dot(q, p)
        res.append(score)
    q_res.append(np.array(res))
q_mat_class = np.matrix(q_res)

In [None]:
# MRR@10 with the class filter: passages whose predicted topic differs from the query's are scored 0
rr_sum = 0
MRR_RANK = 10
for i, passage_scores in enumerate(q_mat):
    n = np.array(passage_scores).flatten()
    # if the class isn't the same, then rank the passage as low
    classification = query_class_list[i]["label"]

    for j, rr in enumerate(n):
        passage_class_classification = passage_class_list[j]["label"]
        # comment out this conditional statement to turn off the filter
        if classification != passage_class_classification:
            n[j] = 0
    
    # query, highest scoring passage, correct passage
    print(
        f"Query: {classification}\n"
        f"Top Passage: {passage_class_list[n.argmax()]['label']}\n"
        f"Expected Passage: {passage_class_list[i]['label']}")
    #print(n)
    ranks = rankdata(n, method="min")
    
    #print(ranks)

    r = len(n) - ranks[i] + 1
    print(f"Rank of Expected: {r}\n")
    if r > MRR_RANK:  # if the rank is above the threshold, then skip
        continue
    rr_sum += 1 / r
mrr_10 = rr_sum / q_mat.shape[0]

print(f"MRR@10: {mrr_10*100}")

Query: diaries_&_daily_life
Top Passage: diaries_&_daily_life
Expected Passage: family
Rank of Expected: 200

Query: food_&_dining
Top Passage: food_&_dining
Expected Passage: food_&_dining
Rank of Expected: 5

Query: news_&_social_concern
Top Passage: news_&_social_concern
Expected Passage: news_&_social_concern
Rank of Expected: 1

Query: travel_&_adventure
Top Passage: travel_&_adventure
Expected Passage: diaries_&_daily_life
Rank of Expected: 200

Query: family
Top Passage: family
Expected Passage: family
Rank of Expected: 1

Query: travel_&_adventure
Top Passage: travel_&_adventure
Expected Passage: diaries_&_daily_life
Rank of Expected: 200

Query: news_&_social_concern
Top Passage: news_&_social_concern
Expected Passage: news_&_social_concern
Rank of Expected: 5

Query: science_&_technology
Top Passage: science_&_technology
Expected Passage: travel_&_adventure
Rank of Expected: 200

Query: fitness_&_health
Top Passage: fitness_&_health
Expected Passage: fitness_&_health
Rank of 

The score filter does not improve performance.  
The next step is to create a rank modifier.
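The notebook does not specify what the rank modifier should look like. One hypothetical form, sketched here purely as an assumption, scales down the score of a passage whose predicted topic differs from the query's instead of setting it to 0, so a relevant passage with a mismatched label can still reach the top 10. The function name `apply_rank_modifier` and the `penalty` value are made up for illustration.

```python
def apply_rank_modifier(scores, query_label, passage_labels, penalty=0.5):
    """Scale down, rather than zero out, passages whose label differs."""
    return [
        s if label == query_label else s * penalty
        for s, label in zip(scores, passage_labels)
    ]

# Made-up scores and labels for three passages.
toy_scores = [10.0, 8.0, 3.0]
toy_labels = ["family", "food_&_dining", "food_&_dining"]
adjusted = apply_rank_modifier(toy_scores, "food_&_dining", toy_labels, penalty=0.5)
best = max(range(len(adjusted)), key=adjusted.__getitem__)
print(adjusted, best)
```

With `penalty=0` this reduces to the hard filter used above, and with `penalty=1` the classifier is ignored, so the value would need tuning on held-out queries.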