# SPLADE: Sparse Lexical and Expansion Model for First Stage Ranking

This notebook gives a minimal usage example of SPLADE.

* In this repo, we provide weights for 2 models (in the `weights` folder)
* See [Naver Labs Europe website](https://europe.naverlabs.com/research/machine-learning-and-optimization/splade-models/) for more up-to-date models under various settings
* We also provide two new models via Hugging Face (https://huggingface.co/naver)

| model | MRR@10 (MS MARCO dev) | recall@1000 (MS MARCO dev) | expected FLOPS | ~ avg q length | ~ avg d length | 
| --- | --- | --- | --- | --- | --- |
| `splade_max` (**v2**) | 34.0 | 96.5 | 1.32 | 18 | 92 |
| `distilsplade_max` (**v2**) | 36.8 | 97.9 | 3.82 | 25 | 232 |
| `naver/splade-cocondenser-selfdistil` (**v2bis**, [HF](https://huggingface.co/naver/splade-cocondenser-selfdistil))| 37.6 | 98.4 | 2.32 | 56 | 134 |
| `naver/splade-cocondenser-ensembledistil` (**v2bis**, [HF](https://huggingface.co/naver/splade-cocondenser-ensembledistil)) | 38.3 | 98.3  | 1.85 | 44 | 120 |

In [None]:
!git clone https://github.com/naver/splade.git

Cloning into 'splade'...
remote: Enumerating objects: 514, done.
remote: Counting objects: 100% (47/47), done.
remote: Compressing objects: 100% (24/24), done.
remote: Total 514 (delta 31), reused 23 (delta 23), pack-reused 467
Receiving objects: 100% (514/514), 3.07 MiB | 26.63 MiB/s, done.
Resolving deltas: 100% (273/273), done.
Filtering content: 100% (2/2), 511.12 MiB | 12.30 MiB/s, done.


In [None]:
pip install ./splade


In [None]:
import numpy as np

In [None]:
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from splade.models.transformer_rep import Splade

In [None]:
# set the dir for trained weights

##### v2
# model_type_or_dir = "weights/splade_max"
# model_type_or_dir = "weights/distilsplade_max"

### v2bis, directly download from Hugging Face
# model_type_or_dir = "naver/splade-cocondenser-selfdistil"
model_type_or_dir = "naver/splade-cocondenser-ensembledistil"


In [None]:
# loading model and tokenizer
model = Splade(model_type_or_dir, agg="max")
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_type_or_dir)
reverse_voc = {v: k for k, v in tokenizer.vocab.items()}


In [None]:
# example document from MS MARCO passage collection (doc_id = 8003157)
doc = """
    Glass and Thermal Stress. 
    Thermal Stress is created 
    when one area of a glass 
    pane gets hotter than an 
    adjacent area. If the stress 
    is too great then the glass 
    will crack. The stress level 
    at which the glass will break 
    is governed by several factors.
    """

In [None]:
# now compute the document representation
with torch.no_grad():
    m = model(
        d_kwargs=tokenizer(
            doc, return_tensors="pt"
            )
        )
    doc_rep = m["d_rep"].squeeze()
    # (sparse) doc rep in voc space, shape (30522,)

# get the number of non-zero dimensions in the rep:
col = torch.nonzero(doc_rep).squeeze().cpu().tolist()
print("number of actual dimensions: ", len(col))

# now let's inspect the bow representation:
weights = doc_rep[col].cpu().tolist()
d = {k: v for k, v in zip(col, weights)}
sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}
bow_rep = []
for k, v in sorted_d.items():
    bow_rep.append((reverse_voc[k], round(v, 2)))
print("SPLADE BOW rep:\n", bow_rep)

number of actual dimensions:  126
SPLADE BOW rep:
 [('stress', 2.25), ('glass', 2.23), ('thermal', 2.18), ('glasses', 1.65), ('pan', 1.62), ('heat', 1.56), ('stressed', 1.42), ('crack', 1.31), ('break', 1.12), ('cracked', 1.1), ('hot', 0.93), ('created', 0.9), ('factors', 0.81), ('broken', 0.73), ('caused', 0.71), ('too', 0.71), ('damage', 0.69), ('if', 0.68), ('hotter', 0.65), ('governed', 0.61), ('heating', 0.59), ('temperature', 0.59), ('adjacent', 0.59), ('cause', 0.58), ('effect', 0.57), ('fracture', 0.56), ('bradford', 0.55), ('strain', 0.53), ('hammer', 0.51), ('brian', 0.48), ('error', 0.47), ('windows', 0.45), ('will', 0.45), ('reaction', 0.42), ('create', 0.42), ('windshield', 0.41), ('heated', 0.41), ('factor', 0.4), ('cracking', 0.39), ('failure', 0.38), ('mechanical', 0.38), ('when', 0.38), ('formed', 0.38), ('bolt', 0.38), ('mechanism', 0.37), ('warm', 0.37), ('areas', 0.36), ('area', 0.36), ('energy', 0.34), ('disorder', 0.33), ('barry', 0.33), ('shock', 0.32), ('determi
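The inspection steps above (find the non-zero dimensions, sort by weight, map ids back to tokens) recur later for queries and passages. A minimal sketch of that logic as a reusable helper follows. It assumes a plain Python list in place of the torch tensor, and the names `sparse_to_bow`, `toy_vocab` and `toy_rep` are hypothetical, with a six-term toy vocabulary standing in for the 30522-term one.

```python
def sparse_to_bow(weights, id_to_token, ndigits=2):
    """Return (token, weight) pairs for the non-zero dimensions, heaviest first."""
    nonzero = [(i, w) for i, w in enumerate(weights) if w != 0]
    nonzero.sort(key=lambda item: item[1], reverse=True)
    return [(id_to_token[i], round(w, ndigits)) for i, w in nonzero]


# hypothetical toy vocabulary and representation (not real model output)
toy_vocab = {0: "glass", 1: "stress", 2: "thermal", 3: "pane", 4: "crack", 5: "the"}
toy_rep = [2.23, 2.25, 2.18, 0.0, 1.31, 0.0]

bow = sparse_to_bow(toy_rep, toy_vocab)
print("number of actual dimensions:", len(bow))
print(bow)
```

With a real representation, something like `doc_rep.tolist()` and `reverse_voc` could be passed in instead of the toy inputs.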

Load the toy dataset from the repo  
https://huggingface.co/docs/datasets/loading#csv


In [None]:
!head /content/splade/data/toy_data/triplets/raw.tsv

is a little caffeine ok during pregnancy	We don’t know a lot about the effects of caffeine during pregnancy on you and your baby. So it’s best to limit the amount you get each day. If you’re pregnant, limit caffeine to 200 milligrams each day. This is about the amount in 1½ 8-ounce cups of coffee or one 12-ounce cup of coffee.	It is generally safe for pregnant women to eat chocolate because studies have shown to prove certain benefits of eating chocolate during pregnancy. However, pregnant women should ensure their caffeine intake is below 200 mg per day.
what fruit is native to australia	Passiflora herbertiana. A rare passion fruit native to Australia. Fruits are green-skinned, white fleshed, with an unknown edible rating. Some sources list the fruit as edible, sweet and tasty, while others list the fruits as being bitter and inedible.assiflora herbertiana. A rare passion fruit native to Australia. Fruits are green-skinned, white fleshed, with an unknown edible rating. Some sou

In [None]:
!pip install datasets


In [None]:
# get the toy dataset
from datasets import load_dataset
fp = "/content/splade/data/toy_data/triplets/raw.tsv"
triplets = load_dataset(
    "csv", 
    data_files=fp,
    sep="\t",
    header=None,
    column_names=["query", "positive", "negative"],
    split="train"
)
# create the Expanded/sparse representations with SPLADE





In [None]:
triplets

Dataset({
    features: ['query', 'positive', 'negative'],
    num_rows: 100
})
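For readers without the `datasets` library, the same tab-separated triplet layout can probably be read with the standard library alone. The sketch below is only an illustration: the two rows are made up, and `toy_triplets` is a hypothetical name, not part of the notebook.

```python
import csv
import io

# two made-up rows in the same query<TAB>positive<TAB>negative layout
raw = (
    "what is splade\tSPLADE is a sparse retrieval model.\tBananas are yellow.\n"
    "capital of france\tParis is the capital of France.\tBerlin is in Germany.\n"
)

# QUOTE_NONE keeps quotation marks inside passages as literal characters
reader = csv.reader(io.StringIO(raw), delimiter="\t", quoting=csv.QUOTE_NONE)
toy_triplets = [
    {"query": q, "positive": pos, "negative": neg} for q, pos, neg in reader
]
print(len(toy_triplets), toy_triplets[0]["query"])
```

For the real file, `io.StringIO(raw)` would be replaced by an opened file handle.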

In [None]:
import datasets
from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset
from tqdm.auto import tqdm

#model.config = fm_config
#pipe = pipeline(model=model_type_or_dir)

# KeyDataset (only *pt*) will simply return the item in the dict returned by the dataset item
# as we're not interested in the *target* part of the dataset. For sentence pair use KeyPairDataset
q_array_list = []
for out in tqdm(KeyDataset(triplets, "query")):
    # inference only, so skip gradient tracking
    with torch.no_grad():
        pt = model(d_kwargs=tokenizer(out, return_tensors="pt"))
    rep = pt["d_rep"].squeeze()

    # get the number of non-zero dimensions in the rep:
    col = torch.nonzero(rep).squeeze().cpu().tolist()
    #print("number of actual dimensions: ", len(col))

    # now let's inspect the bow representation:
    weights = rep[col].cpu().tolist()
    d = {k: v for k, v in zip(col, weights)}
    sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}
    bow_rep = []
    for k, v in sorted_d.items():
        bow_rep.append((k, reverse_voc[k], round(v, 2)))
    q_array_list.append(np.array(bow_rep))

  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
combo = list(set(triplets["positive"] + triplets["negative"]))

In [None]:
p_array_list = []
for out in tqdm(triplets["positive"]):
    # inference only, so skip gradient tracking
    with torch.no_grad():
        pt = model(d_kwargs=tokenizer(out, return_tensors="pt"))
    rep = pt["d_rep"].squeeze()
    # get the number of non-zero dimensions in the rep:
    col = torch.nonzero(rep).squeeze().cpu().tolist()
    #print("number of actual dimensions: ", len(col))

    # now let's inspect the bow representation:
    weights = rep[col].cpu().tolist()
    d = {k: v for k, v in zip(col, weights)}
    sorted_d = {k: v for k, v in sorted(d.items(), key=lambda item: item[1], reverse=True)}
    bow_rep = []
    for k, v in sorted_d.items():
        bow_rep.append((k, reverse_voc[k], round(v, 2)))
    p_array_list.append(np.array(bow_rep))

  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# get the strings
for p in p_array_list:
    string = ' '.join([pp[1] for pp in p])
    print(string, '\n')

pregnancy caf coffee ##fe pregnant baby limit ounce dose mg daily ##ine max during cup child ##ories day babies total drink amount mill cafe cups oz weight effects effect drinking limited diet ##ram grams should dos restriction todd affect limiting much maternity you bottle every soda ##mg harm safe get amounts milk please mom intake 8 ##ception limits supplement consumption breakfast pill restricted avoid ban content addiction 12 everyday children toxic fetal ##zine ##eding coke stomach pre newborn cal women parents carl best having ##tine eat reduce normal ##carriage son sleep take rule help 200 

herbert passion pass ##iana ##lor rare australia ##if edible native ass bitter fruit passionate green ##ible australian white flesh ta bitterness species sweet fruits ##a unknown ##sty skinned variety uncommon whites skin origin is ##ed habitat rating indigenous plant greene carmen culture varieties tree ##ia in color flavor flower berry skins food ##able apple location red citrus taste ##a

In [None]:
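def get_vec(arr):
    # dense vector over the full vocabulary, zero except for the active terms
    z = np.zeros(len(tokenizer.vocab))
    # rows are (token_id, token, weight); np.array stored all three as strings
    for token_id, _token, weight in arr:
        z[int(token_id)] = float(weight)
    return z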

In [None]:
qv = [get_vec(q) for q in q_array_list]
pv = [get_vec(p) for p in p_array_list]

In [None]:
# get the scores of each passage for each query:
# q_mat[i, j] is the dot product of query i and passage j
q_mat = np.array(qv) @ np.array(pv).T
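The dense vectors above are mostly zeros, so the same score can also be computed directly on the sparse form. The following is a minimal sketch, assuming each representation is held as a `{token_id: weight}` dict. The function `sparse_dot` and the toy ids and weights are hypothetical.

```python
def sparse_dot(q, p):
    """Dot product of two sparse vectors stored as {token_id: weight} dicts."""
    # iterate over the shorter dict and look terms up in the longer one
    if len(q) > len(p):
        q, p = p, q
    return sum(w * p[i] for i, w in q.items() if i in p)


# made-up token ids and weights, for illustration only
query_rep = {101: 1.5, 205: 0.5}
passage_a = {101: 2.0, 307: 1.0}
passage_b = {205: 1.0, 307: 3.0}

scores = [sparse_dot(query_rep, p) for p in (passage_a, passage_b)]
print(scores)  # [3.0, 0.5]
```

Only terms present in both vectors contribute, which is what makes this kind of representation usable with an inverted index.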

# Evaluate Performance
Calculate the MRR@10  
https://machinelearning.wtf/terms/mean-reciprocal-rank-mrr/

In [None]:
from scipy.stats import rankdata

In [None]:
# calculate the mean reciprocal rank
rr_sum = 0
MRR_RANK = 10
for i, passage_scores in enumerate(q_mat):
    n = np.array(passage_scores).flatten()
    ranks = rankdata(n, method="min")
    r = len(n) - ranks[i] + 1
    if r > MRR_RANK:  # if the rank is above the threshold, then skip
        continue
    rr_sum += 1 / r
mrr_10 = rr_sum / q_mat.shape[0]
print(f"MRR@10: {mrr_10*100}")

MRR@10: 99.0
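As a cross-check of the metric, here is a small dependency-free sketch of MRR@k on made-up scores. It assumes one relevant passage per query and counts ties against the relevant passage, which matches the `rankdata(..., method="min")` conversion used above. The name `mrr_at_k` and the toy numbers are hypothetical.

```python
def mrr_at_k(score_rows, relevant_idx, k=10):
    """Mean reciprocal rank, counting 0 for queries ranked below position k."""
    total = 0.0
    for scores, rel in zip(score_rows, relevant_idx):
        # rank = number of passages scoring at least as high as the relevant one
        rank = sum(s >= scores[rel] for s in scores)
        if rank <= k:
            total += 1.0 / rank
    return total / len(score_rows)


# three toy queries over three toy passages
rows = [[0.9, 0.1, 0.3], [0.2, 0.8, 0.5], [0.7, 0.6, 0.1]]
relevant = [0, 2, 2]  # index of the relevant passage for each query

print(mrr_at_k(rows, relevant, k=10))  # ranks 1, 2, 3 -> (1 + 1/2 + 1/3) / 3
print(mrr_at_k(rows, relevant, k=2))   # the rank-3 query now contributes 0
```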


In [None]:
# using a pipeline
from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset

# import the classifier
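MODEL = "cardiffnlp/tweet-topic-21-multi"
# use the first GPU when available, otherwise fall back to CPU
device = 0 if torch.cuda.is_available() else -1
generator = pipeline(model=MODEL, return_all_scores=True, device=device)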

Downloading:   0%|          | 0.00/1.33k [00:00<?, ?B/s]

In [None]:
# classify the passages
passage_class_list = []
passage_full_class_list = []
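# rebuild each passage as a string of its expanded tokens (column 1 of every row)
passage_text = [' '.join(pp[1] for pp in p) for p in p_array_list]
for out in tqdm(generator(passage_text, truncation=True, max_length=512)):
    passage_full_class_list.append(out)
    # keep only the highest-scoring topic
    passage_class_list.append(max(out, key=lambda x: x["score"]))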

  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
query_class_list = []
query_full_class_list = []
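# rebuild each query as a string of its expanded tokens (column 1 of every row)
query_text = [' '.join(qq[1] for qq in q) for q in q_array_list]
for out in tqdm(generator(query_text, truncation=True, max_length=512)):
    query_full_class_list.append(out)
    # keep only the highest-scoring topic
    query_class_list.append(max(out, key=lambda x: x["score"]))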

  0%|          | 0/100 [00:00<?, ?it/s]

In [None]:
# hard topic filter: zero the score of every passage whose top topic differs from the query's
rr_sum = 0
MRR_RANK = 10
for i, passage_scores in enumerate(q_mat):
    n = np.array(passage_scores).flatten()
    # if the class isn't the same, then rank the passage as low
    classification = query_class_list[i]["label"]

    for j, rr in enumerate(n):
        passage_class_classification = passage_class_list[j]["label"]
        if classification != passage_class_classification:
            n[j] = 0
    
    # query, highest scoring passage, correct passage
    print(
        f"Query: {classification}\n"
        f"Top Passage: {passage_class_list[n.argmax()]['label']}\n"
        f"Expected Passage: {passage_class_list[i]['label']}")
    ranks = rankdata(n, method="min")
    
    #print(ranks)

    r = len(n) - ranks[i] + 1
    print(f"Rank of Expected: {r}\n")
    if r > MRR_RANK:  # if the rank is above the threshold, then skip
        continue
    rr_sum += 1 / r
mrr_10 = rr_sum / q_mat.shape[0]

print(f"MRR@10: {mrr_10*100}")

Query: diaries_&_daily_life
Top Passage: diaries_&_daily_life
Expected Passage: food_&_dining
Rank of Expected: 100

Query: diaries_&_daily_life
Top Passage: diaries_&_daily_life
Expected Passage: music
Rank of Expected: 100

Query: travel_&_adventure
Top Passage: food_&_dining
Expected Passage: news_&_social_concern
Rank of Expected: 100

Query: diaries_&_daily_life
Top Passage: diaries_&_daily_life
Expected Passage: diaries_&_daily_life
Rank of Expected: 1

Query: other_hobbies
Top Passage: other_hobbies
Expected Passage: business_&_entrepreneurs
Rank of Expected: 100

Query: diaries_&_daily_life
Top Passage: diaries_&_daily_life
Expected Passage: news_&_social_concern
Rank of Expected: 100

Query: diaries_&_daily_life
Top Passage: diaries_&_daily_life
Expected Passage: news_&_social_concern
Rank of Expected: 100

Query: news_&_social_concern
Top Passage: news_&_social_concern
Expected Passage: news_&_social_concern
Rank of Expected: 1

Query: diaries_&_daily_life
Top Passage: diarie

The hard topic filter doesn't improve performance.  
The next step is to create a rank modifier that rescales the scores instead of zeroing them.
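To make the difference concrete, here is a minimal sketch on made-up numbers. The hard filter keeps a score only when the passage's top topic equals the query's top topic. The rank modifier multiplies each score by the probability the query assigns to the passage's topic. All names and values below are hypothetical.

```python
scores = [4.0, 3.0, 2.0]                  # toy retrieval scores
passage_topics = ["music", "food", "news"]  # top topic of each passage
query_topic_probs = {"food": 0.6, "news": 0.3, "music": 0.1}
top_query_topic = max(query_topic_probs, key=query_topic_probs.get)

# hard filter: passages from any other topic drop to zero and all tie at the bottom
hard = [s if t == top_query_topic else 0.0 for s, t in zip(scores, passage_topics)]

# rank modifier: off-topic passages are demoted but keep a relative order
soft = [s * query_topic_probs.get(t, 0.0) for s, t in zip(scores, passage_topics)]

print(hard)                          # [0.0, 3.0, 0.0]
print([round(x, 2) for x in soft])   # [0.4, 1.8, 0.6]
```

With the hard filter, a relevant passage whose topic is misclassified can never be recovered. With the modifier it can still rank above other off-topic passages.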

In [None]:
# query as multi-topic and passages as single topic
rr_sum = 0
MRR_RANK = 10
for i, passage_scores in enumerate(q_mat):
    n = np.array(passage_scores).flatten()
    # if the class isn't the same, then rank the passage as low
    #print(query_full_class_list[i])
    classification = {}
    for d in query_full_class_list[i]:
        classification[d['label']] = d['score']

    for j, rr in enumerate(n):
        passage_class_classification = passage_class_list[j]["label"]
        n[j]*=classification.get(passage_class_classification, 0)

    # query, highest scoring passage, correct passage
    print(
        f"Query: {query_class_list[i]['label']}\n"
        f"Top Passage: {passage_class_list[n.argmax()]['label']}\n"
        f"Expected Passage: {passage_class_list[i]['label']}")
    #print(n)
    ranks = rankdata(n, method="min")
    
    #print(ranks)

    r = len(n) - ranks[i] + 1
    print(f"Rank of Expected: {r}\n")
    if r > MRR_RANK:  # if the rank is above the threshold, then skip
        continue
    rr_sum += 1 / r
mrr_10 = rr_sum / q_mat.shape[0]

print(f"MRR@10: {mrr_10*100}")

Query: diaries_&_daily_life
Top Passage: diaries_&_daily_life
Expected Passage: food_&_dining
Rank of Expected: 31

Query: diaries_&_daily_life
Top Passage: diaries_&_daily_life
Expected Passage: music
Rank of Expected: 24

Query: travel_&_adventure
Top Passage: news_&_social_concern
Expected Passage: news_&_social_concern
Rank of Expected: 1

Query: diaries_&_daily_life
Top Passage: diaries_&_daily_life
Expected Passage: diaries_&_daily_life
Rank of Expected: 1

Query: other_hobbies
Top Passage: other_hobbies
Expected Passage: business_&_entrepreneurs
Rank of Expected: 6

Query: diaries_&_daily_life
Top Passage: news_&_social_concern
Expected Passage: news_&_social_concern
Rank of Expected: 1

Query: diaries_&_daily_life
Top Passage: diaries_&_daily_life
Expected Passage: news_&_social_concern
Rank of Expected: 6

Query: news_&_social_concern
Top Passage: news_&_social_concern
Expected Passage: news_&_social_concern
Rank of Expected: 1

Query: diaries_&_daily_life
Top Passage: diaries

Modify the query-passage embedding similarity with a topic similarity computed from the topic classification probabilities.
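A small sketch of the idea on made-up numbers: the topic similarity is taken as the dot product of the query's and the passage's topic probability vectors, and the retrieval score is multiplied by it. The labels, probabilities and scores are hypothetical, and multiplication is just one possible way to combine the two signals.

```python
def dot(u, v):
    """Plain dot product of two equal-length sequences."""
    return sum(a * b for a, b in zip(u, v))


labels = ["food", "music", "news"]  # fixed label order shared by every vector
query_topics = [0.7, 0.1, 0.2]
passage_topics = [[0.8, 0.1, 0.1], [0.1, 0.8, 0.1]]
lexical_scores = [2.0, 3.0]         # toy retrieval scores for the two passages

topic_sim = [dot(query_topics, p) for p in passage_topics]
combined = [s * t for s, t in zip(lexical_scores, topic_sim)]

best_lexical = max(range(len(lexical_scores)), key=lexical_scores.__getitem__)
best_combined = max(range(len(combined)), key=combined.__getitem__)
print(best_lexical, best_combined)  # 1 0
```

In this toy case the topic signal flips the top result from the second passage to the first.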

In [None]:
def list_dict_to_dict(l: list) -> dict:
    output_dict = {}
    for d in l:
        output_dict[d["label"]] = d["score"]
    return output_dict

def dict_to_vec(d):
    return [y[1] for y in sorted(d.items(), key=lambda x: x[0])]

In [None]:
query_topic_vec_list = [dict_to_vec(list_dict_to_dict(ld)) for ld in query_full_class_list]
passage_topic_vec_list = [dict_to_vec(list_dict_to_dict(ld)) for ld in passage_full_class_list]

In [None]:
# topic similarity for every query/passage pair:
# qt_mat[i, j] is the dot product of the topic probability vectors of query i and passage j
qt_mat = np.array(query_topic_vec_list) @ np.array(passage_topic_vec_list).T

In [None]:
# rescale the SPLADE scores by the query/passage topic similarity
rr_sum = 0
MRR_RANK = 10
for i, passage_scores in enumerate(q_mat):
    n = np.array(passage_scores).flatten()
    nt = np.array(qt_mat[i]).flatten()
    n_mod = n*nt
    
    print(
        f"Query: {query_class_list[i]['label']}\n"
        f"Top Passage: {passage_class_list[n_mod.argmax()]['label']}\n"
        #f"Top Topic Passage: {}\n"
        f"Expected Passage: {passage_class_list[i]['label']}")
    #print(n)
    ranks = rankdata(n_mod, method="min")
    
    #print(ranks)

    r = len(n) - ranks[i] + 1
    print(f"Rank of Expected: {r}\n")
    if r > MRR_RANK:  # if the rank is above the threshold, then skip
        continue
    rr_sum += 1 / r
mrr_10 = rr_sum / q_mat.shape[0]

print(f"MRR@10: {mrr_10*100}")

Query: diaries_&_daily_life
Top Passage: food_&_dining
Expected Passage: food_&_dining
Rank of Expected: 1

Query: diaries_&_daily_life
Top Passage: diaries_&_daily_life
Expected Passage: music
Rank of Expected: 6

Query: travel_&_adventure
Top Passage: news_&_social_concern
Expected Passage: news_&_social_concern
Rank of Expected: 1

Query: diaries_&_daily_life
Top Passage: diaries_&_daily_life
Expected Passage: diaries_&_daily_life
Rank of Expected: 1

Query: other_hobbies
Top Passage: food_&_dining
Expected Passage: business_&_entrepreneurs
Rank of Expected: 2

Query: diaries_&_daily_life
Top Passage: news_&_social_concern
Expected Passage: news_&_social_concern
Rank of Expected: 1

Query: diaries_&_daily_life
Top Passage: diaries_&_daily_life
Expected Passage: news_&_social_concern
Rank of Expected: 2

Query: news_&_social_concern
Top Passage: news_&_social_concern
Expected Passage: news_&_social_concern
Rank of Expected: 1

Query: diaries_&_daily_life
Top Passage: diaries_&_daily_

Create a framework to compare results. We want to know whether there are certain topics on which a new model performs worse.

WIP  
Try on a larger dataset.  
Try on a smaller number of topics.
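One possible shape for such a comparison, sketched on made-up ranks: group the reciprocal ranks by topic label, average them per topic, and list the topics where the candidate model scores lower than the baseline. The function name `mrr_by_topic` and all data are hypothetical.

```python
from collections import defaultdict


def mrr_by_topic(ranks, topics, k=10):
    """Average reciprocal rank per topic, counting 0 for ranks beyond k."""
    buckets = defaultdict(list)
    for rank, topic in zip(ranks, topics):
        buckets[topic].append(1.0 / rank if rank <= k else 0.0)
    return {t: sum(v) / len(v) for t, v in buckets.items()}


# toy data: topic of each query and rank of its relevant passage under two models
topics = ["food", "food", "music", "news"]
baseline_ranks = [1, 2, 1, 4]
candidate_ranks = [1, 1, 12, 2]

base = mrr_by_topic(baseline_ranks, topics)
cand = mrr_by_topic(candidate_ranks, topics)
worse = sorted(t for t in base if cand[t] < base[t])
print(worse)  # ['music']
```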

In [None]:
!wget https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz

--2023-02-17 04:32:37--  https://msmarco.blob.core.windows.net/msmarcoranking/collectionandqueries.tar.gz
Resolving msmarco.blob.core.windows.net (msmarco.blob.core.windows.net)... 20.150.34.4
Connecting to msmarco.blob.core.windows.net (msmarco.blob.core.windows.net)|20.150.34.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1057717952 (1009M) [application/gzip]
Saving to: ‘collectionandqueries.tar.gz.1’


2023-02-17 04:35:54 (5.14 MB/s) - ‘collectionandqueries.tar.gz.1’ saved [1057717952/1057717952]



In [None]:
!tar -xvzf collectionandqueries.tar.gz

collection.tsv
qrels.dev.small.tsv
qrels.train.tsv
queries.dev.small.tsv
queries.dev.tsv
queries.eval.small.tsv
queries.eval.tsv
queries.train.tsv


In [None]:
!head queries.train.tsv

121352	define extreme
634306	what does chattel mean on credit history
920825	what was the great leap forward brainly
510633	tattoo fixers how much does it cost
737889	what is decentralization process.
278900	how many cars enter the la jolla concours d' elegance?
674172	what is a bank transit number
303205	how much can i contribute to nondeductible ira
570009	what are the four major groups of elements
492875	sanitizer temperature


In [None]:
!head qrels.train.tsv

1185869	0	0	1
1185868	0	16	1
597651	0	49	1
403613	0	60	1
1183785	0	389	1
312651	0	616	1
80385	0	723	1
645590	0	944	1
645337	0	1054	1
186154	0	1160	1


In [None]:
import pandas as pd

In [None]:
queries = pd.read_csv("queries.train.tsv", sep='\t', header=None)
queries.columns = ["id", "text"]
queries.set_index("id", inplace=True)

In [None]:
queries.head()

id,text
121352,define extreme
634306,what does chattel mean on credit history
920825,what was the great leap forward brainly
510633,tattoo fixers how much does it cost
737889,what is decentralization process.


In [None]:
passages = pd.read_csv("collection.tsv", sep='\t', header=None)
passages.columns = ["id", "text"]

In [None]:
passages.set_index("id", inplace=True)

In [None]:
# qrels
qrels = pd.read_csv("qrels.train.tsv", sep='\t', header=None)
qrels.columns = ["qid", "iteration", "pid", "relevancy"]
#qrels.set_index("id", inplace=True)
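For a quick look at the four-column layout (query id, iteration, passage id, relevance) without pandas, the rows can be folded into a mapping from each query to its relevant passages. This is only a sketch on the first three rows printed above, and `relevant` is a hypothetical name.

```python
from collections import defaultdict

# first three rows of qrels.train.tsv as shown above
sample = "1185869\t0\t0\t1\n1185868\t0\t16\t1\n597651\t0\t49\t1\n"

relevant = defaultdict(set)  # qid -> set of relevant pids
for line in sample.splitlines():
    qid, _iteration, pid, relevancy = line.split("\t")
    if int(relevancy) > 0:
        relevant[int(qid)].add(int(pid))

print(dict(relevant))
```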

In [None]:
qrels

,qid,iteration,pid,relevancy
0,1185869,0,0,1
1,1185868,0,16,1
2,597651,0,49,1
3,403613,0,60,1
4,1183785,0,389,1
...,...,...,...,...
532756,19285,0,8841362,1
532757,558837,0,4989159,1
532758,559149,0,8841547,1
532759,706678,0,8841643,1


In [None]:
# sample from qrels
qrels_sample = qrels.sample(frac=.1)
qrels_joined = qrels_sample.join(queries, on="qid", how="left")
qrels_joined.rename(columns={"text": "query"}, inplace=True)
qrels_joined = qrels_joined.join(passages, on="pid", how="left")
qrels_joined.rename(columns={"text": "passage"}, inplace=True)
qrels_joined.to_csv("all_data_sample.csv")

In [None]:
# precompute the topic classification for passages
from datasets import Dataset
dataset = Dataset.from_pandas(qrels_joined)

In [None]:
from tqdm import tqdm

In [None]:
out_list = []
for out in tqdm(generator(KeyDataset(dataset, "passage"))):
    out_list.append(list_dict_to_dict(out))

100%|██████████| 53276/53276 [09:21<00:00, 94.86it/s]


In [None]:
qrels_joined["passage_class"] = out_list

In [None]:
qrels_joined.to_csv("all_data_sample.csv")

In [None]:
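# A GPU-backed pipeline generally cannot be shared with forked worker processes,
# so run the map in a single process (no num_proc).
# For a single string the pipeline returns a one-element list, hence the [0].
updated_data = dataset.map(
    lambda example: list_dict_to_dict(generator(example["passage"])[0])
)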